diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..5d86aae87a51390e571dbe0d52aceefa14864f2a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +advanced.png filter=lfs diff=lfs merge=lfs -text +flow.gif filter=lfs diff=lfs merge=lfs -text +library/__pycache__/train_util.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text +publish_to_hf.png filter=lfs diff=lfs merge=lfs -text +sample.png filter=lfs diff=lfs merge=lfs -text +screenshot.png filter=lfs diff=lfs merge=lfs -text +seed.gif filter=lfs diff=lfs merge=lfs -text diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..df0f4804ec94ec47358904347160c5eb8ec4cdd4 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,53 @@ +# Base image with CUDA 12.2 +FROM nvidia/cuda:12.2.2-base-ubuntu22.04 + +# Install pip if not already installed +RUN apt-get update -y && apt-get install -y \ + python3-pip \ + python3-dev \ + git \ + build-essential # Install dependencies for building extensions + +# Define environment variables for UID and GID and local timezone +# ENV PUID=${PUID:-1000} +# ENV PGID=${PGID:-1000} + +# Create a group with the specified GID +# RUN groupadd -g "${PGID}" appuser +# Create a user with the specified UID and GID +# RUN useradd -m -s /bin/sh -u "${PUID}" -g "${PGID}" appuser + +WORKDIR /app + +# Get sd-scripts from kohya-ss and install them +RUN git clone -b sd3 https://github.com/kohya-ss/sd-scripts && \ + cd sd-scripts && \ + pip install --no-cache-dir -r ./requirements.txt + +# Install main application dependencies +COPY ./requirements.txt ./requirements.txt +RUN pip install --no-cache-dir -r ./requirements.txt + +# Install Torch, Torchvision, and Torchaudio for CUDA 12.2 +RUN pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu122/torch_stable.html + +RUN chown -R appuser:appuser /app + +# delete redundant requirements.txt and sd-scripts directory within the container +RUN rm -r ./sd-scripts +RUN rm ./requirements.txt +RUN pip install --force-reinstall -v "triton==3.1.0" +#Run application as non-root +# USER appuser + +# Copy fluxgym application code +COPY . ./fluxgym + +EXPOSE 7860 + +ENV GRADIO_SERVER_NAME="0.0.0.0" + +WORKDIR /app/fluxgym + +# Run fluxgym Python application +CMD ["python3", "./app.py"] diff --git a/Dockerfile.cuda12.4 b/Dockerfile.cuda12.4 new file mode 100644 index 0000000000000000000000000000000000000000..52fc97b84b8c625f25148f69f471601c89a336e7 --- /dev/null +++ b/Dockerfile.cuda12.4 @@ -0,0 +1,53 @@ +# Base image with CUDA 12.4 +FROM nvidia/cuda:12.4.1-base-ubuntu22.04 + +# Install pip if not already installed +RUN apt-get update -y && apt-get install -y \ + python3-pip \ + python3-dev \ + git \ + build-essential # Install dependencies for building extensions + +# Define environment variables for UID and GID and local timezone +ENV PUID=${PUID:-1000} +ENV PGID=${PGID:-1000} + +# Create a group with the specified GID +RUN groupadd -g "${PGID}" appuser +# Create a user with the specified UID and GID +RUN useradd -m -s /bin/sh -u "${PUID}" -g "${PGID}" appuser + +WORKDIR /app + +# Get sd-scripts from kohya-ss and install them +RUN git clone -b sd3 https://github.com/kohya-ss/sd-scripts && \ + cd sd-scripts && \ + pip install --no-cache-dir -r ./requirements.txt + +# Install main application dependencies +COPY ./requirements.txt ./requirements.txt +RUN pip install --no-cache-dir -r ./requirements.txt + +# Install Torch, Torchvision, and Torchaudio for CUDA 12.4 +RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 + +RUN chown -R appuser:appuser /app + +# delete redundant requirements.txt and sd-scripts directory within the container +RUN rm -r ./sd-scripts +RUN rm ./requirements.txt + +#Run application as non-root +USER appuser + +# Copy fluxgym application code +COPY . ./fluxgym + +EXPOSE 7860 + +ENV GRADIO_SERVER_NAME="0.0.0.0" + +WORKDIR /app/fluxgym + +# Run fluxgym Python application +CMD ["python3", "./app.py"] \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..394fb377927a836da8872cf92587d46fea5fc3c2 --- /dev/null +++ b/LICENSE @@ -0,0 +1,7 @@ +Copyright 2024 cocktailpeanut + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f31391dc0c61911f77ceafdb637658f2d8506397 --- /dev/null +++ b/README.md @@ -0,0 +1,54 @@ +# ComfyUI Flux Trainer + +Wrapper for slightly modified kohya's training scripts: https://github.com/kohya-ss/sd-scripts + +Including code from: https://github.com/KohakuBlueleaf/Lycoris + +And https://github.com/LoganBooker/prodigy-plus-schedule-free + +## DISCLAIMER: +I have **very** little previous experience in training anything, Flux is basically first model I've been inspired to learn. Previously I've only trained AnimateDiff Motion Loras, and built similar training nodes for it. + +## DO NOT ASK ME FOR TRAINING ADVICE +I can not emphasize this enough, this repository is not for raising questions related to the training itself, that would be better done to kohya's repo. Even so keep in mind my implementation may have mistakes. + +The default settings aren't necessarily any good, they are just the last (out of many) I've tried and worked for my dataset. + +# THIS IS EXPERIMENTAL +Both these nodes and the underlaying implementation by kohya is work in progress and expected to change. + +# Installation +1. Clone this repo into `custom_nodes` folder. +2. Install dependencies: `pip install -r requirements.txt` + or if you use the portable install, run this in ComfyUI_windows_portable -folder: + + `python_embeded\python.exe -m pip install -r ComfyUI\custom_nodes\ComfyUI-FluxTrainer\requirements.txt` + +In addition torch version 2.4.0 or higher is highly recommended. + +Example workflow for LoRA training can be found in the examples folder, it utilizes additional nodes from: + +https://github.com/kijai/ComfyUI-KJNodes + +And some (optional) debugging nodes from: + +https://github.com/rgthree/rgthree-comfy + +For LoRA training the models need to be the normal fp8 or fp16 versions, also make sure the VAE is the non-diffusers version: + +https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/ae.safetensors + +For full model training the fp16 version of the main model needs to be used. + +## Why train in ComfyUI? +- Familiar UI (obviously only if you are a Comfy user already) +- You can use same models you use for inference +- You can use same python environment, I faced no incompabilities +- You can build workflows to compare settings etc. + +Currently supports LoRA training, and untested full finetune with code from kohya's scripts: https://github.com/kohya-ss/sd-scripts + +Experimental support for LyCORIS training has been added as well, using code from: https://github.com/KohakuBlueleaf/Lycoris + +![Screenshot 2024-08-21 020207](https://github.com/user-attachments/assets/1686b180-90c8-41d0-8c96-63e76ebc2475) + diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..903b45c39b88fc625deede5eab5ed2adbac9b73a --- /dev/null +++ b/__init__.py @@ -0,0 +1,12 @@ +from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS +from .nodes_sd3 import NODE_CLASS_MAPPINGS as NODE_CLASS_MAPPINGS_SD3 +from .nodes_sd3 import NODE_DISPLAY_NAME_MAPPINGS as NODE_DISPLAY_NAME_MAPPINGS_SD3 +from .nodes_sdxl import NODE_CLASS_MAPPINGS as NODE_CLASS_MAPPINGS_SDXL +from .nodes_sdxl import NODE_DISPLAY_NAME_MAPPINGS as NODE_DISPLAY_NAME_MAPPINGS_SDXL + +NODE_CLASS_MAPPINGS.update(NODE_CLASS_MAPPINGS_SD3) +NODE_CLASS_MAPPINGS.update(NODE_CLASS_MAPPINGS_SDXL) +NODE_DISPLAY_NAME_MAPPINGS.update(NODE_DISPLAY_NAME_MAPPINGS_SD3) +NODE_DISPLAY_NAME_MAPPINGS.update(NODE_DISPLAY_NAME_MAPPINGS_SDXL) + +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] \ No newline at end of file diff --git a/__pycache__/train_network.cpython-310.pyc b/__pycache__/train_network.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50acfd740a845cd766f3f6d5fa75fb6e3b8d91e5 Binary files /dev/null and b/__pycache__/train_network.cpython-310.pyc differ diff --git a/advanced.png b/advanced.png new file mode 100644 index 0000000000000000000000000000000000000000..6fba3b38a0fbcd47249cf0c6efee7ecb8b4cba1d --- /dev/null +++ b/advanced.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15077625eb185463cc0dd383157879fe3b73ebb7305a40f5ed2af14a49bca41d +size 181656 diff --git a/app-launch.sh b/app-launch.sh new file mode 100644 index 0000000000000000000000000000000000000000..33dcd61e302171534e9ba26f859ea153f2c0605a --- /dev/null +++ b/app-launch.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash + +cd "`dirname "$0"`" || exit 1 +. env/bin/activate +python app.py diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..04378d66893baddebbcbbe7539f3f109ec7d4cad --- /dev/null +++ b/app.py @@ -0,0 +1,1119 @@ +import os +import sys +os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +os.environ['GRADIO_ANALYTICS_ENABLED'] = '0' +sys.path.insert(0, os.getcwd()) +sys.path.append(os.path.join(os.path.dirname(__file__), 'sd-scripts')) +import subprocess +import gradio as gr +from PIL import Image +import torch +import uuid +import shutil +import json +import yaml +from slugify import slugify +from transformers import AutoProcessor, AutoModelForCausalLM +from gradio_logsview import LogsView, LogsViewRunner +from huggingface_hub import hf_hub_download, HfApi +from library import flux_train_utils, huggingface_util +from argparse import Namespace +import train_network +import toml +import re +MAX_IMAGES = 150 + +with open('models.yaml', 'r') as file: + models = yaml.safe_load(file) + +def readme(base_model, lora_name, instance_prompt, sample_prompts): + + # model license + model_config = models[base_model] + model_file = model_config["file"] + base_model_name = model_config["base"] + license = None + license_name = None + license_link = None + license_items = [] + if "license" in model_config: + license = model_config["license"] + license_items.append(f"license: {license}") + if "license_name" in model_config: + license_name = model_config["license_name"] + license_items.append(f"license_name: {license_name}") + if "license_link" in model_config: + license_link = model_config["license_link"] + license_items.append(f"license_link: {license_link}") + license_str = "\n".join(license_items) + print(f"license_items={license_items}") + print(f"license_str = {license_str}") + + # tags + tags = [ "text-to-image", "flux", "lora", "diffusers", "template:sd-lora", "fluxgym" ] + + # widgets + widgets = [] + sample_image_paths = [] + output_name = slugify(lora_name) + samples_dir = resolve_path_without_quotes(f"outputs/{output_name}/sample") + try: + for filename in os.listdir(samples_dir): + # Filename Schema: [name]_[steps]_[index]_[timestamp].png + match = re.search(r"_(\d+)_(\d+)_(\d+)\.png$", filename) + if match: + steps, index, timestamp = int(match.group(1)), int(match.group(2)), int(match.group(3)) + sample_image_paths.append((steps, index, f"sample/{filename}")) + + # Sort by numeric index + sample_image_paths.sort(key=lambda x: x[0], reverse=True) + + final_sample_image_paths = sample_image_paths[:len(sample_prompts)] + final_sample_image_paths.sort(key=lambda x: x[1]) + for i, prompt in enumerate(sample_prompts): + _, _, image_path = final_sample_image_paths[i] + widgets.append( + { + "text": prompt, + "output": { + "url": image_path + }, + } + ) + except: + print(f"no samples") + dtype = "torch.bfloat16" + # Construct the README content + readme_content = f"""--- +tags: +{yaml.dump(tags, indent=4).strip()} +{"widget:" if os.path.isdir(samples_dir) else ""} +{yaml.dump(widgets, indent=4).strip() if widgets else ""} +base_model: {base_model_name} +{"instance_prompt: " + instance_prompt if instance_prompt else ""} +{license_str} +--- + +# {lora_name} + +A Flux LoRA trained on a local computer with [Fluxgym](https://github.com/cocktailpeanut/fluxgym) + + + +## Trigger words + +{"You should use `" + instance_prompt + "` to trigger the image generation." if instance_prompt else "No trigger words defined."} + +## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, Forge, etc. + +Weights for this model are available in Safetensors format. + +""" + return readme_content + +def account_hf(): + try: + with open("HF_TOKEN", "r") as file: + token = file.read() + api = HfApi(token=token) + try: + account = api.whoami() + return { "token": token, "account": account['name'] } + except: + return None + except: + return None + +""" +hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner]) +""" +def logout_hf(): + os.remove("HF_TOKEN") + global current_account + current_account = account_hf() + print(f"current_account={current_account}") + return gr.update(value=""), gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False) + + +""" +hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner]) +""" +def login_hf(hf_token): + api = HfApi(token=hf_token) + try: + account = api.whoami() + if account != None: + if "name" in account: + with open("HF_TOKEN", "w") as file: + file.write(hf_token) + global current_account + current_account = account_hf() + return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(value=current_account["account"], visible=True) + return gr.update(), gr.update(), gr.update(), gr.update() + except: + print(f"incorrect hf_token") + return gr.update(), gr.update(), gr.update(), gr.update() + +def upload_hf(base_model, lora_rows, repo_owner, repo_name, repo_visibility, hf_token): + src = lora_rows + repo_id = f"{repo_owner}/{repo_name}" + gr.Info(f"Uploading to Huggingface. Please Stand by...", duration=None) + args = Namespace( + huggingface_repo_id=repo_id, + huggingface_repo_type="model", + huggingface_repo_visibility=repo_visibility, + huggingface_path_in_repo="", + huggingface_token=hf_token, + async_upload=False + ) + print(f"upload_hf args={args}") + huggingface_util.upload(args=args, src=src) + gr.Info(f"[Upload Complete] https://huggingface.co/{repo_id}", duration=None) + +def load_captioning(uploaded_files, concept_sentence): + uploaded_images = [file for file in uploaded_files if not file.endswith('.txt')] + txt_files = [file for file in uploaded_files if file.endswith('.txt')] + txt_files_dict = {os.path.splitext(os.path.basename(txt_file))[0]: txt_file for txt_file in txt_files} + updates = [] + if len(uploaded_images) <= 1: + raise gr.Error( + "Please upload at least 2 images to train your model (the ideal number with default settings is between 4-30)" + ) + elif len(uploaded_images) > MAX_IMAGES: + raise gr.Error(f"For now, only {MAX_IMAGES} or less images are allowed for training") + # Update for the captioning_area + # for _ in range(3): + updates.append(gr.update(visible=True)) + # Update visibility and image for each captioning row and image + for i in range(1, MAX_IMAGES + 1): + # Determine if the current row and image should be visible + visible = i <= len(uploaded_images) + + # Update visibility of the captioning row + updates.append(gr.update(visible=visible)) + + # Update for image component - display image if available, otherwise hide + image_value = uploaded_images[i - 1] if visible else None + updates.append(gr.update(value=image_value, visible=visible)) + + corresponding_caption = False + if(image_value): + base_name = os.path.splitext(os.path.basename(image_value))[0] + if base_name in txt_files_dict: + with open(txt_files_dict[base_name], 'r') as file: + corresponding_caption = file.read() + + # Update value of captioning area + text_value = corresponding_caption if visible and corresponding_caption else concept_sentence if visible and concept_sentence else None + updates.append(gr.update(value=text_value, visible=visible)) + + # Update for the sample caption area + updates.append(gr.update(visible=True)) + updates.append(gr.update(visible=True)) + + return updates + +def hide_captioning(): + return gr.update(visible=False), gr.update(visible=False) + +def resize_image(image_path, output_path, size): + with Image.open(image_path) as img: + width, height = img.size + if width < height: + new_width = size + new_height = int((size/width) * height) + else: + new_height = size + new_width = int((size/height) * width) + print(f"resize {image_path} : {new_width}x{new_height}") + img_resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS) + img_resized.save(output_path) + +def create_dataset(destination_folder, size, *inputs): + print("Creating dataset") + images = inputs[0] + if not os.path.exists(destination_folder): + os.makedirs(destination_folder) + + for index, image in enumerate(images): + # copy the images to the datasets folder + new_image_path = shutil.copy(image, destination_folder) + + # if it's a caption text file skip the next bit + ext = os.path.splitext(new_image_path)[-1].lower() + if ext == '.txt': + continue + + # resize the images + resize_image(new_image_path, new_image_path, size) + + # copy the captions + + original_caption = inputs[index + 1] + + image_file_name = os.path.basename(new_image_path) + caption_file_name = os.path.splitext(image_file_name)[0] + ".txt" + caption_path = resolve_path_without_quotes(os.path.join(destination_folder, caption_file_name)) + print(f"image_path={new_image_path}, caption_path = {caption_path}, original_caption={original_caption}") + # if caption_path exists, do not write + if os.path.exists(caption_path): + print(f"{caption_path} already exists. use the existing .txt file") + else: + print(f"{caption_path} create a .txt caption file") + with open(caption_path, 'w') as file: + file.write(original_caption) + + print(f"destination_folder {destination_folder}") + return destination_folder + + +def run_captioning(images, concept_sentence, *captions): + print(f"run_captioning") + print(f"concept sentence {concept_sentence}") + print(f"captions {captions}") + #Load internally to not consume resources for training + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"device={device}") + torch_dtype = torch.float16 + model = AutoModelForCausalLM.from_pretrained( + "multimodalart/Florence-2-large-no-flash-attn", torch_dtype=torch_dtype, trust_remote_code=True + ).to(device) + processor = AutoProcessor.from_pretrained("multimodalart/Florence-2-large-no-flash-attn", trust_remote_code=True) + + captions = list(captions) + for i, image_path in enumerate(images): + print(captions[i]) + if isinstance(image_path, str): # If image is a file path + image = Image.open(image_path).convert("RGB") + + prompt = "" + inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype) + print(f"inputs {inputs}") + + generated_ids = model.generate( + input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3 + ) + print(f"generated_ids {generated_ids}") + + generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] + print(f"generated_text: {generated_text}") + parsed_answer = processor.post_process_generation( + generated_text, task=prompt, image_size=(image.width, image.height) + ) + print(f"parsed_answer = {parsed_answer}") + caption_text = parsed_answer[""].replace("The image shows ", "") + print(f"caption_text = {caption_text}, concept_sentence={concept_sentence}") + if concept_sentence: + caption_text = f"{concept_sentence} {caption_text}" + captions[i] = caption_text + + yield captions + model.to("cpu") + del model + del processor + if torch.cuda.is_available(): + torch.cuda.empty_cache() + +def recursive_update(d, u): + for k, v in u.items(): + if isinstance(v, dict) and v: + d[k] = recursive_update(d.get(k, {}), v) + else: + d[k] = v + return d + +def download(base_model): + model = models[base_model] + model_file = model["file"] + repo = model["repo"] + + # download unet + if base_model == "flux-dev" or base_model == "flux-schnell": + unet_folder = "models/unet" + else: + unet_folder = f"models/unet/{repo}" + unet_path = os.path.join(unet_folder, model_file) + if not os.path.exists(unet_path): + os.makedirs(unet_folder, exist_ok=True) + gr.Info(f"Downloading base model: {base_model}. Please wait. (You can check the terminal for the download progress)", duration=None) + print(f"download {base_model}") + hf_hub_download(repo_id=repo, local_dir=unet_folder, filename=model_file) + + # download vae + vae_folder = "models/vae" + vae_path = os.path.join(vae_folder, "ae.sft") + if not os.path.exists(vae_path): + os.makedirs(vae_folder, exist_ok=True) + gr.Info(f"Downloading vae") + print(f"downloading ae.sft...") + hf_hub_download(repo_id="cocktailpeanut/xulf-dev", local_dir=vae_folder, filename="ae.sft") + + # download clip + clip_folder = "models/clip" + clip_l_path = os.path.join(clip_folder, "clip_l.safetensors") + if not os.path.exists(clip_l_path): + os.makedirs(clip_folder, exist_ok=True) + gr.Info(f"Downloading clip...") + print(f"download clip_l.safetensors") + hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", local_dir=clip_folder, filename="clip_l.safetensors") + + # download t5xxl + t5xxl_path = os.path.join(clip_folder, "t5xxl_fp16.safetensors") + if not os.path.exists(t5xxl_path): + print(f"download t5xxl_fp16.safetensors") + gr.Info(f"Downloading t5xxl...") + hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", local_dir=clip_folder, filename="t5xxl_fp16.safetensors") + + +def resolve_path(p): + current_dir = os.path.dirname(os.path.abspath(__file__)) + norm_path = os.path.normpath(os.path.join(current_dir, p)) + return f"\"{norm_path}\"" +def resolve_path_without_quotes(p): + current_dir = os.path.dirname(os.path.abspath(__file__)) + norm_path = os.path.normpath(os.path.join(current_dir, p)) + return norm_path + +def gen_sh( + base_model, + output_name, + resolution, + seed, + workers, + learning_rate, + network_dim, + max_train_epochs, + save_every_n_epochs, + timestep_sampling, + guidance_scale, + vram, + sample_prompts, + sample_every_n_steps, + *advanced_components +): + + print(f"gen_sh: network_dim:{network_dim}, max_train_epochs={max_train_epochs}, save_every_n_epochs={save_every_n_epochs}, timestep_sampling={timestep_sampling}, guidance_scale={guidance_scale}, vram={vram}, sample_prompts={sample_prompts}, sample_every_n_steps={sample_every_n_steps}") + + output_dir = resolve_path(f"outputs/{output_name}") + sample_prompts_path = resolve_path(f"outputs/{output_name}/sample_prompts.txt") + + line_break = "\\" + file_type = "sh" + if sys.platform == "win32": + line_break = "^" + file_type = "bat" + + ############# Sample args ######################## + sample = "" + if len(sample_prompts) > 0 and sample_every_n_steps > 0: + sample = f"""--sample_prompts={sample_prompts_path} --sample_every_n_steps="{sample_every_n_steps}" {line_break}""" + + + ############# Optimizer args ######################## +# if vram == "8G": +# optimizer = f"""--optimizer_type adafactor {line_break} +# --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break} +# --split_mode {line_break} +# --network_args "train_blocks=single" {line_break} +# --lr_scheduler constant_with_warmup {line_break} +# --max_grad_norm 0.0 {line_break}""" + if vram == "16G": + # 16G VRAM + optimizer = f"""--optimizer_type adafactor {line_break} + --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break} + --lr_scheduler constant_with_warmup {line_break} + --max_grad_norm 0.0 {line_break}""" + elif vram == "12G": + # 12G VRAM + optimizer = f"""--optimizer_type adafactor {line_break} + --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break} + --split_mode {line_break} + --network_args "train_blocks=single" {line_break} + --lr_scheduler constant_with_warmup {line_break} + --max_grad_norm 0.0 {line_break}""" + else: + # 20G+ VRAM + optimizer = f"--optimizer_type adamw8bit {line_break}" + + + ####################################################### + model_config = models[base_model] + model_file = model_config["file"] + repo = model_config["repo"] + if base_model == "flux-dev" or base_model == "flux-schnell": + model_folder = "models/unet" + else: + model_folder = f"models/unet/{repo}" + model_path = os.path.join(model_folder, model_file) + pretrained_model_path = resolve_path(model_path) + + clip_path = resolve_path("models/clip/clip_l.safetensors") + t5_path = resolve_path("models/clip/t5xxl_fp16.safetensors") + ae_path = resolve_path("models/vae/ae.sft") + sh = f"""accelerate launch {line_break} + --mixed_precision bf16 {line_break} + --num_cpu_threads_per_process 1 {line_break} + sd-scripts/flux_train_network.py {line_break} + --pretrained_model_name_or_path {pretrained_model_path} {line_break} + --clip_l {clip_path} {line_break} + --t5xxl {t5_path} {line_break} + --ae {ae_path} {line_break} + --cache_latents_to_disk {line_break} + --save_model_as safetensors {line_break} + --sdpa --persistent_data_loader_workers {line_break} + --max_data_loader_n_workers {workers} {line_break} + --seed {seed} {line_break} + --gradient_checkpointing {line_break} + --mixed_precision bf16 {line_break} + --save_precision bf16 {line_break} + --network_module networks.lora_flux {line_break} + --network_dim {network_dim} {line_break} + {optimizer}{sample} + --learning_rate {learning_rate} {line_break} + --cache_text_encoder_outputs {line_break} + --cache_text_encoder_outputs_to_disk {line_break} + --fp8_base {line_break} + --highvram {line_break} + --max_train_epochs {max_train_epochs} {line_break} + --save_every_n_epochs {save_every_n_epochs} {line_break} + --dataset_config {resolve_path(f"outputs/{output_name}/dataset.toml")} {line_break} + --output_dir {output_dir} {line_break} + --output_name {output_name} {line_break} + --timestep_sampling {timestep_sampling} {line_break} + --discrete_flow_shift 3.1582 {line_break} + --model_prediction_type raw {line_break} + --guidance_scale {guidance_scale} {line_break} + --loss_type l2 {line_break}""" + + + + ############# Advanced args ######################## + global advanced_component_ids + global original_advanced_component_values + + # check dirty + print(f"original_advanced_component_values = {original_advanced_component_values}") + advanced_flags = [] + for i, current_value in enumerate(advanced_components): +# print(f"compare {advanced_component_ids[i]}: old={original_advanced_component_values[i]}, new={current_value}") + if original_advanced_component_values[i] != current_value: + # dirty + if current_value == True: + # Boolean + advanced_flags.append(advanced_component_ids[i]) + else: + # string + advanced_flags.append(f"{advanced_component_ids[i]} {current_value}") + + if len(advanced_flags) > 0: + advanced_flags_str = f" {line_break}\n ".join(advanced_flags) + sh = sh + "\n " + advanced_flags_str + + return sh + +def gen_toml( + dataset_folder, + resolution, + class_tokens, + num_repeats +): + toml = f"""[general] +shuffle_caption = false +caption_extension = '.txt' +keep_tokens = 1 + +[[datasets]] +resolution = {resolution} +batch_size = 1 +keep_tokens = 1 + + [[datasets.subsets]] + image_dir = '{resolve_path_without_quotes(dataset_folder)}' + class_tokens = '{class_tokens}' + num_repeats = {num_repeats}""" + return toml + +def update_total_steps(max_train_epochs, num_repeats, images): + try: + num_images = len(images) + total_steps = max_train_epochs * num_images * num_repeats + print(f"max_train_epochs={max_train_epochs} num_images={num_images}, num_repeats={num_repeats}, total_steps={total_steps}") + return gr.update(value = total_steps) + except: + print("") + +def set_repo(lora_rows): + selected_name = os.path.basename(lora_rows) + return gr.update(value=selected_name) + +def get_loras(): + try: + outputs_path = resolve_path_without_quotes(f"outputs") + files = os.listdir(outputs_path) + folders = [os.path.join(outputs_path, item) for item in files if os.path.isdir(os.path.join(outputs_path, item)) and item != "sample"] + folders.sort(key=lambda file: os.path.getctime(file), reverse=True) + return folders + except Exception as e: + return [] + +def get_samples(lora_name): + output_name = slugify(lora_name) + try: + samples_path = resolve_path_without_quotes(f"outputs/{output_name}/sample") + files = [os.path.join(samples_path, file) for file in os.listdir(samples_path)] + files.sort(key=lambda file: os.path.getctime(file), reverse=True) + return files + except: + return [] + +def start_training( + base_model, + lora_name, + train_script, + train_config, + sample_prompts, +): + # write custom script and toml + if not os.path.exists("models"): + os.makedirs("models", exist_ok=True) + if not os.path.exists("outputs"): + os.makedirs("outputs", exist_ok=True) + output_name = slugify(lora_name) + output_dir = resolve_path_without_quotes(f"outputs/{output_name}") + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + download(base_model) + + file_type = "sh" + if sys.platform == "win32": + file_type = "bat" + + sh_filename = f"train.{file_type}" + sh_filepath = resolve_path_without_quotes(f"outputs/{output_name}/{sh_filename}") + with open(sh_filepath, 'w', encoding="utf-8") as file: + file.write(train_script) + gr.Info(f"Generated train script at {sh_filename}") + + + dataset_path = resolve_path_without_quotes(f"outputs/{output_name}/dataset.toml") + with open(dataset_path, 'w', encoding="utf-8") as file: + file.write(train_config) + gr.Info(f"Generated dataset.toml") + + sample_prompts_path = resolve_path_without_quotes(f"outputs/{output_name}/sample_prompts.txt") + with open(sample_prompts_path, 'w', encoding='utf-8') as file: + file.write(sample_prompts) + gr.Info(f"Generated sample_prompts.txt") + + # Train + if sys.platform == "win32": + command = sh_filepath + else: + command = f"bash \"{sh_filepath}\"" + + # Use Popen to run the command and capture output in real-time + env = os.environ.copy() + env['PYTHONIOENCODING'] = 'utf-8' + env['LOG_LEVEL'] = 'DEBUG' + runner = LogsViewRunner() + cwd = os.path.dirname(os.path.abspath(__file__)) + gr.Info(f"Started training") + yield from runner.run_command([command], cwd=cwd) + yield runner.log(f"Runner: {runner}") + + # Generate Readme + config = toml.loads(train_config) + concept_sentence = config['datasets'][0]['subsets'][0]['class_tokens'] + print(f"concept_sentence={concept_sentence}") + print(f"lora_name {lora_name}, concept_sentence={concept_sentence}, output_name={output_name}") + sample_prompts_path = resolve_path_without_quotes(f"outputs/{output_name}/sample_prompts.txt") + with open(sample_prompts_path, "r", encoding="utf-8") as f: + lines = f.readlines() + sample_prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] + md = readme(base_model, lora_name, concept_sentence, sample_prompts) + readme_path = resolve_path_without_quotes(f"outputs/{output_name}/README.md") + with open(readme_path, "w", encoding="utf-8") as f: + f.write(md) + + gr.Info(f"Training Complete. Check the outputs folder for the LoRA files.", duration=None) + + +def update( + base_model, + lora_name, + resolution, + seed, + workers, + class_tokens, + learning_rate, + network_dim, + max_train_epochs, + save_every_n_epochs, + timestep_sampling, + guidance_scale, + vram, + num_repeats, + sample_prompts, + sample_every_n_steps, + *advanced_components, +): + output_name = slugify(lora_name) + dataset_folder = str(f"datasets/{output_name}") + sh = gen_sh( + base_model, + output_name, + resolution, + seed, + workers, + learning_rate, + network_dim, + max_train_epochs, + save_every_n_epochs, + timestep_sampling, + guidance_scale, + vram, + sample_prompts, + sample_every_n_steps, + *advanced_components, + ) + toml = gen_toml( + dataset_folder, + resolution, + class_tokens, + num_repeats + ) + return gr.update(value=sh), gr.update(value=toml), dataset_folder + +""" +demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, hf_account]) +""" +def loaded(): + global current_account + current_account = account_hf() + print(f"current_account={current_account}") + if current_account != None: + return gr.update(value=current_account["token"]), gr.update(visible=False), gr.update(visible=True), gr.update(value=current_account["account"], visible=True) + else: + return gr.update(value=""), gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False) + +def update_sample(concept_sentence): + return gr.update(value=concept_sentence) + +def refresh_publish_tab(): + loras = get_loras() + return gr.Dropdown(label="Trained LoRAs", choices=loras) + +def init_advanced(): + # if basic_args + basic_args = { + 'pretrained_model_name_or_path', + 'clip_l', + 't5xxl', + 'ae', + 'cache_latents_to_disk', + 'save_model_as', + 'sdpa', + 'persistent_data_loader_workers', + 'max_data_loader_n_workers', + 'seed', + 'gradient_checkpointing', + 'mixed_precision', + 'save_precision', + 'network_module', + 'network_dim', + 'learning_rate', + 'cache_text_encoder_outputs', + 'cache_text_encoder_outputs_to_disk', + 'fp8_base', + 'highvram', + 'max_train_epochs', + 'save_every_n_epochs', + 'dataset_config', + 'output_dir', + 'output_name', + 'timestep_sampling', + 'discrete_flow_shift', + 'model_prediction_type', + 'guidance_scale', + 'loss_type', + 'optimizer_type', + 'optimizer_args', + 'lr_scheduler', + 'sample_prompts', + 'sample_every_n_steps', + 'max_grad_norm', + 'split_mode', + 'network_args' + } + + # generate a UI config + # if not in basic_args, create a simple form + parser = train_network.setup_parser() + flux_train_utils.add_flux_train_arguments(parser) + args_info = {} + for action in parser._actions: + if action.dest != 'help': # Skip the default help argument + # if the dest is included in basic_args + args_info[action.dest] = { + "action": action.option_strings, # Option strings like '--use_8bit_adam' + "type": action.type, # Type of the argument + "help": action.help, # Help message + "default": action.default, # Default value, if any + "required": action.required # Whether the argument is required + } + temp = [] + for key in args_info: + temp.append({ 'key': key, 'action': args_info[key] }) + temp.sort(key=lambda x: x['key']) + advanced_component_ids = [] + advanced_components = [] + for item in temp: + key = item['key'] + action = item['action'] + if key in basic_args: + print("") + else: + action_type = str(action['type']) + component = None + with gr.Column(min_width=300): + if action_type == "None": + # radio + component = gr.Checkbox() + # elif action_type == "": + # component = gr.Textbox() + # elif action_type == "": + # component = gr.Number(precision=0) + # elif action_type == "": + # component = gr.Number() + # elif "int_or_float" in action_type: + # component = gr.Number() + else: + component = gr.Textbox(value="") + if component != None: + component.interactive = True + component.elem_id = action['action'][0] + component.label = component.elem_id + component.elem_classes = ["advanced"] + if action['help'] != None: + component.info = action['help'] + advanced_components.append(component) + advanced_component_ids.append(component.elem_id) + return advanced_components, advanced_component_ids + + +theme = gr.themes.Monochrome( + text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"), + font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"], +) +css = """ +@keyframes rotate { + 0% { + transform: rotate(0deg); + } + 100% { + transform: rotate(360deg); + } +} +#advanced_options .advanced:nth-child(even) { background: rgba(0,0,100,0.04) !important; } +h1{font-family: georgia; font-style: italic; font-weight: bold; font-size: 30px; letter-spacing: -1px;} +h3{margin-top: 0} +.tabitem{border: 0px} +.group_padding{} +nav{position: fixed; top: 0; left: 0; right: 0; z-index: 1000; text-align: center; padding: 10px; box-sizing: border-box; display: flex; align-items: center; backdrop-filter: blur(10px); } +nav button { background: none; color: firebrick; font-weight: bold; border: 2px solid firebrick; padding: 5px 10px; border-radius: 5px; font-size: 14px; } +nav img { height: 40px; width: 40px; border-radius: 40px; } +nav img.rotate { animation: rotate 2s linear infinite; } +.flexible { flex-grow: 1; } +.tast-details { margin: 10px 0 !important; } +.toast-wrap { bottom: var(--size-4) !important; top: auto !important; border: none !important; backdrop-filter: blur(10px); } +.toast-title, .toast-text, .toast-icon, .toast-close { color: black !important; font-size: 14px; } +.toast-body { border: none !important; } +#terminal { box-shadow: none !important; margin-bottom: 25px; background: rgba(0,0,0,0.03); } +#terminal .generating { border: none !important; } +#terminal label { position: absolute !important; } +.tabs { margin-top: 50px; } +.hidden { display: none !important; } +.codemirror-wrapper .cm-line { font-size: 12px !important; } +label { font-weight: bold !important; } +#start_training.clicked { background: silver; color: black; } +""" + +js = """ +function() { + let autoscroll = document.querySelector("#autoscroll") + if (window.iidxx) { + window.clearInterval(window.iidxx); + } + window.iidxx = window.setInterval(function() { + let text=document.querySelector(".codemirror-wrapper .cm-line").innerText.trim() + let img = document.querySelector("#logo") + if (text.length > 0) { + autoscroll.classList.remove("hidden") + if (autoscroll.classList.contains("on")) { + autoscroll.textContent = "Autoscroll ON" + window.scrollTo(0, document.body.scrollHeight, { behavior: "smooth" }); + img.classList.add("rotate") + } else { + autoscroll.textContent = "Autoscroll OFF" + img.classList.remove("rotate") + } + } + }, 500); + console.log("autoscroll", autoscroll) + autoscroll.addEventListener("click", (e) => { + autoscroll.classList.toggle("on") + }) + function debounce(fn, delay) { + let timeoutId; + return function(...args) { + clearTimeout(timeoutId); + timeoutId = setTimeout(() => fn(...args), delay); + }; + } + + function handleClick() { + console.log("refresh") + document.querySelector("#refresh").click(); + } + const debouncedClick = debounce(handleClick, 1000); + document.addEventListener("input", debouncedClick); + + document.querySelector("#start_training").addEventListener("click", (e) => { + e.target.classList.add("clicked") + e.target.innerHTML = "Training..." + }) + +} +""" + +current_account = account_hf() +print(f"current_account={current_account}") + +with gr.Blocks(elem_id="app", theme=theme, css=css, fill_width=True) as demo: + with gr.Tabs() as tabs: + with gr.TabItem("Gym"): + output_components = [] + with gr.Row(): + gr.HTML(""" + """) + with gr.Row(elem_id='container'): + with gr.Column(): + gr.Markdown( + """# Step 1. LoRA Info +

Configure your LoRA train settings.

+ """, elem_classes="group_padding") + lora_name = gr.Textbox( + label="The name of your LoRA", + info="This has to be a unique name", + placeholder="e.g.: Persian Miniature Painting style, Cat Toy", + ) + concept_sentence = gr.Textbox( + elem_id="--concept_sentence", + label="Trigger word/sentence", + info="Trigger word or sentence to be used", + placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'", + interactive=True, + ) + model_names = list(models.keys()) + print(f"model_names={model_names}") + base_model = gr.Dropdown(label="Base model (edit the models.yaml file to add more to this list)", choices=model_names, value=model_names[0]) + vram = gr.Radio(["20G", "16G", "12G" ], value="20G", label="VRAM", interactive=True) + num_repeats = gr.Number(value=10, precision=0, label="Repeat trains per image", interactive=True) + max_train_epochs = gr.Number(label="Max Train Epochs", value=16, interactive=True) + total_steps = gr.Number(0, interactive=False, label="Expected training steps") + sample_prompts = gr.Textbox("", lines=5, label="Sample Image Prompts (Separate with new lines)", interactive=True) + sample_every_n_steps = gr.Number(0, precision=0, label="Sample Image Every N Steps", interactive=True) + resolution = gr.Number(value=512, precision=0, label="Resize dataset images") + with gr.Column(): + gr.Markdown( + """# Step 2. Dataset +

Make sure the captions include the trigger word.

+ """, elem_classes="group_padding") + with gr.Group(): + images = gr.File( + file_types=["image", ".txt"], + label="Upload your images", + #info="If you want, you can also manually upload caption files that match the image names (example: img0.png => img0.txt)", + file_count="multiple", + interactive=True, + visible=True, + scale=1, + ) + with gr.Group(visible=False) as captioning_area: + do_captioning = gr.Button("Add AI captions with Florence-2") + output_components.append(captioning_area) + #output_components = [captioning_area] + caption_list = [] + for i in range(1, MAX_IMAGES + 1): + locals()[f"captioning_row_{i}"] = gr.Row(visible=False) + with locals()[f"captioning_row_{i}"]: + locals()[f"image_{i}"] = gr.Image( + type="filepath", + width=111, + height=111, + min_width=111, + interactive=False, + scale=2, + show_label=False, + show_share_button=False, + show_download_button=False, + ) + locals()[f"caption_{i}"] = gr.Textbox( + label=f"Caption {i}", scale=15, interactive=True + ) + + output_components.append(locals()[f"captioning_row_{i}"]) + output_components.append(locals()[f"image_{i}"]) + output_components.append(locals()[f"caption_{i}"]) + caption_list.append(locals()[f"caption_{i}"]) + with gr.Column(): + gr.Markdown( + """# Step 3. Train +

Press start to start training.

+ """, elem_classes="group_padding") + refresh = gr.Button("Refresh", elem_id="refresh", visible=False) + start = gr.Button("Start training", visible=False, elem_id="start_training") + output_components.append(start) + train_script = gr.Textbox(label="Train script", max_lines=100, interactive=True) + train_config = gr.Textbox(label="Train config", max_lines=100, interactive=True) + with gr.Accordion("Advanced options", elem_id='advanced_options', open=False): + with gr.Row(): + with gr.Column(min_width=300): + seed = gr.Number(label="--seed", info="Seed", value=42, interactive=True) + with gr.Column(min_width=300): + workers = gr.Number(label="--max_data_loader_n_workers", info="Number of Workers", value=2, interactive=True) + with gr.Column(min_width=300): + learning_rate = gr.Textbox(label="--learning_rate", info="Learning Rate", value="8e-4", interactive=True) + with gr.Column(min_width=300): + save_every_n_epochs = gr.Number(label="--save_every_n_epochs", info="Save every N epochs", value=4, interactive=True) + with gr.Column(min_width=300): + guidance_scale = gr.Number(label="--guidance_scale", info="Guidance Scale", value=1.0, interactive=True) + with gr.Column(min_width=300): + timestep_sampling = gr.Textbox(label="--timestep_sampling", info="Timestep Sampling", value="shift", interactive=True) + with gr.Column(min_width=300): + network_dim = gr.Number(label="--network_dim", info="LoRA Rank", value=4, minimum=4, maximum=128, step=4, interactive=True) + advanced_components, advanced_component_ids = init_advanced() + with gr.Row(): + terminal = LogsView(label="Train log", elem_id="terminal") + with gr.Row(): + gallery = gr.Gallery(get_samples, inputs=[lora_name], label="Samples", every=10, columns=6) + + with gr.TabItem("Publish") as publish_tab: + hf_token = gr.Textbox(label="Huggingface Token") + hf_login = gr.Button("Login") + hf_logout = gr.Button("Logout") + with gr.Row() as row: + gr.Markdown("**LoRA**") + gr.Markdown("**Upload**") + loras = get_loras() + with gr.Row(): + lora_rows = refresh_publish_tab() + with gr.Column(): + with gr.Row(): + repo_owner = gr.Textbox(label="Account", interactive=False) + repo_name = gr.Textbox(label="Repository Name") + repo_visibility = gr.Textbox(label="Repository Visibility ('public' or 'private')", value="public") + upload_button = gr.Button("Upload to HuggingFace") + upload_button.click( + fn=upload_hf, + inputs=[ + base_model, + lora_rows, + repo_owner, + repo_name, + repo_visibility, + hf_token, + ] + ) + hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner]) + hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner]) + + + publish_tab.select(refresh_publish_tab, outputs=lora_rows) + lora_rows.select(fn=set_repo, inputs=[lora_rows], outputs=[repo_name]) + + dataset_folder = gr.State() + + listeners = [ + base_model, + lora_name, + resolution, + seed, + workers, + concept_sentence, + learning_rate, + network_dim, + max_train_epochs, + save_every_n_epochs, + timestep_sampling, + guidance_scale, + vram, + num_repeats, + sample_prompts, + sample_every_n_steps, + *advanced_components + ] + advanced_component_ids = [x.elem_id for x in advanced_components] + original_advanced_component_values = [comp.value for comp in advanced_components] + images.upload( + load_captioning, + inputs=[images, concept_sentence], + outputs=output_components + ) + images.delete( + load_captioning, + inputs=[images, concept_sentence], + outputs=output_components + ) + images.clear( + hide_captioning, + outputs=[captioning_area, start] + ) + max_train_epochs.change( + fn=update_total_steps, + inputs=[max_train_epochs, num_repeats, images], + outputs=[total_steps] + ) + num_repeats.change( + fn=update_total_steps, + inputs=[max_train_epochs, num_repeats, images], + outputs=[total_steps] + ) + images.upload( + fn=update_total_steps, + inputs=[max_train_epochs, num_repeats, images], + outputs=[total_steps] + ) + images.delete( + fn=update_total_steps, + inputs=[max_train_epochs, num_repeats, images], + outputs=[total_steps] + ) + images.clear( + fn=update_total_steps, + inputs=[max_train_epochs, num_repeats, images], + outputs=[total_steps] + ) + concept_sentence.change(fn=update_sample, inputs=[concept_sentence], outputs=sample_prompts) + start.click(fn=create_dataset, inputs=[dataset_folder, resolution, images] + caption_list, outputs=dataset_folder).then( + fn=start_training, + inputs=[ + base_model, + lora_name, + train_script, + train_config, + sample_prompts, + ], + outputs=terminal, + ) + do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list) + demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, repo_owner]) + refresh.click(update, inputs=listeners, outputs=[train_script, train_config, dataset_folder]) +if __name__ == "__main__": + cwd = os.path.dirname(os.path.abspath(__file__)) + demo.launch(debug=True, show_error=True, allowed_paths=[cwd]) diff --git a/datasets/1 b/datasets/1 new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..998b18251a7854d2d03de682bb39dd2ede796b51 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,28 @@ +services: + + fluxgym: + build: + context: . + # change the dockerfile to Dockerfile.cuda12.4 if you are running CUDA 12.4 drivers otherwise leave as is + dockerfile: Dockerfile + image: fluxgym + container_name: fluxgym + ports: + - 7860:7860 + environment: + - PUID=${PUID:-1000} + - PGID=${PGID:-1000} + volumes: + - /etc/localtime:/etc/localtime:ro + - /etc/timezone:/etc/timezone:ro + - ./:/app/fluxgym + stop_signal: SIGKILL + tty: true + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + restart: unless-stopped \ No newline at end of file diff --git a/fine_tune.py b/fine_tune.py new file mode 100644 index 0000000000000000000000000000000000000000..ae7d9c8aa1523da692f3dbe70ab2181694908fdf --- /dev/null +++ b/fine_tune.py @@ -0,0 +1,560 @@ +# training with captions +# XXX dropped option: hypernetwork training + +import argparse +import math +import os +from multiprocessing import Value +import toml + +from tqdm import tqdm + +import torch +from library import deepspeed_utils, strategy_base +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from diffusers import DDPMScheduler + +from .utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + apply_snr_weight, + get_weighted_text_embeddings, + prepare_scheduler_for_custom_training, + scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, +) +import library.strategy_sd as strategy_sd + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + cache_latents = args.cache_latents + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if cache_latents: + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator) + + # verify load/save model formats + if load_stable_diffusion_format: + src_stable_diffusion_ckpt = args.pretrained_model_name_or_path + src_diffusers_model_path = None + else: + src_stable_diffusion_ckpt = None + src_diffusers_model_path = args.pretrained_model_name_or_path + + if args.save_model_as is None: + save_stable_diffusion_format = load_stable_diffusion_format + use_safetensors = args.use_safetensors + else: + save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" + use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) + + # Diffusers版のxformers使用フラグを設定する関数 + def set_diffusers_xformers_flag(model, valid): + # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう + # pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`) + # U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか + # 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^) + + # Recursively walk through all the children. + # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + fn_recursive_set_mem_eff(model) + + # モデルに xformers とか memory efficient attention を組み込む + if args.diffusers_xformers: + accelerator.print("Use xformers by Diffusers") + set_diffusers_xformers_flag(unet, True) + else: + # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある + accelerator.print("Disable Diffusers' xformers") + set_diffusers_xformers_flag(unet, False) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + + train_dataset_group.new_cache_latents(vae, accelerator) + + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # 学習を準備する:モデルを適切な状態にする + training_models = [] + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + training_models.append(unet) + + if args.train_text_encoder: + accelerator.print("enable text encoder training") + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + training_models.append(text_encoder) + else: + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) # text encoderは学習しない + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + text_encoder.train() # required for gradient_checkpointing + else: + text_encoder.eval() + + text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + for m in training_models: + m.requires_grad_(True) + + trainable_params = [] + if args.learning_rate_te is None or not args.train_text_encoder: + for m in training_models: + trainable_params.extend(m.parameters()) + else: + trainable_params = [ + {"params": list(unet.parameters()), "lr": args.learning_rate}, + {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te}, + ] + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + unet.to(weight_dtype) + text_encoder.to(weight_dtype) + + if args.deepspeed: + if args.train_text_encoder: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder) + else: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet) + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + else: + # acceleratorがなんかよろしくやってくれるらしい + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + accelerator.print( + f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + # For --sample_at_first + train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) + + loss_recorder = train_util.LossRecorder() + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(*training_models): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype) + latents = latents * 0.18215 + b_size = latents.shape[0] + + with torch.set_grad_enabled(args.train_text_encoder): + # Get the text embedding for conditioning + if args.weighted_captions: + # TODO move to strategy_sd.py + encoder_hidden_states = get_weighted_text_embeddings( + tokenize_strategy.tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + input_ids = batch["input_ids_list"][0].to(accelerator.device) + encoder_hidden_states = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder], [input_ids] + )[0] + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: + # do not mean over batch dimension for snr weight or scale v-pred loss + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) + loss = loss.mean([1, 2, 3]) + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + + loss = loss.mean() # mean over batch dimension + else: + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + ) + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(text_encoder), + accelerator.unwrap_model(unet), + vae, + ) + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if args.logging_dir is not None: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(text_encoder), + accelerator.unwrap_model(unet), + vae, + ) + + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) + + is_main_process = accelerator.is_main_process + if is_main_process: + unet = accelerator.unwrap_model(unet) + text_encoder = accelerator.unwrap_model(text_encoder) + + accelerator.end_training() + + if is_main_process and (args.save_state or args.save_state_on_train_end): + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_train_end( + args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae + ) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument( + "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する" + ) + parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + parser.add_argument( + "--learning_rate_te", + type=float, + default=None, + help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ", + ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/flags.png b/flags.png new file mode 100644 index 0000000000000000000000000000000000000000..049b127bbdffd1318ec0b70ffc309d81d76a08c8 Binary files /dev/null and b/flags.png differ diff --git a/flow.gif b/flow.gif new file mode 100644 index 0000000000000000000000000000000000000000..6f54af79b26e1b01f1ce82b0252525947c5ad4b4 --- /dev/null +++ b/flow.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e502e5bcbfd25f5d7bad10e0b57a88c8f3b24006792d3a273d7bd964634a8fd9 +size 11349766 diff --git a/flux_extract_lora.py b/flux_extract_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..661cbba5b052d26fe4a09b72715f108171992110 --- /dev/null +++ b/flux_extract_lora.py @@ -0,0 +1,221 @@ +# extract approximating LoRA by svd from two FLUX models +# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo! + +import argparse +import json +import os +import time +import torch +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from tqdm import tqdm +from .library import flux_utils, sai_model_spec +from .library.utils import MemoryEfficientSafeOpen +from .library.utils import setup_logging +from .networks import lora_flux + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from comfy.utils import ProgressBar +# CLAMP_QUANTILE = 0.99 +# MIN_DIFF = 1e-1 + + +def save_to_file(file_name, state_dict, metadata, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + save_file(state_dict, file_name, metadata=metadata) + + +def svd( + model_org=None, + model_tuned=None, + save_to=None, + dim=4, + device=None, + store_device='cpu', + save_precision=None, + clamp_quantile=0.99, + min_diff=0.01, + no_metadata=False, + mem_eff_safe_open=False, +): + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + calc_dtype = torch.float + save_dtype = str_to_dtype(save_precision) + + # open models + lora_weights = {} + if not mem_eff_safe_open: + # use original safetensors.safe_open + open_fn = lambda fn: safe_open(fn, framework="pt") + else: + logger.info("Using memory efficient safe_open") + open_fn = lambda fn: MemoryEfficientSafeOpen(fn) + + with open_fn(model_org) as fo: + # filter keys + keys = [] + for key in fo.keys(): + if not ("single_block" in key or "double_block" in key): + continue + if ".bias" in key: + continue + if "norm" in key: + continue + keys.append(key) + comfy_pbar = ProgressBar(len(keys)) + with open_fn(model_tuned) as ft: + for key in tqdm(keys): + # get tensors and calculate difference + value_o = fo.get_tensor(key) + value_t = ft.get_tensor(key) + mat = value_t.to(calc_dtype) - value_o.to(calc_dtype) + del value_o, value_t + + # extract LoRA weights + if device: + mat = mat.to(device) + out_dim, in_dim = mat.size()[0:2] + rank = min(dim, in_dim, out_dim) # LoRA rank cannot exceed the original dim + + mat = mat.squeeze() + + U, S, Vh = torch.linalg.svd(mat) + + U = U[:, :rank] + S = S[:rank] + U = U @ torch.diag(S) + + Vh = Vh[:rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, clamp_quantile) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + + U = U.to(store_device, dtype=save_dtype).contiguous() + Vh = Vh.to(store_device, dtype=save_dtype).contiguous() + + print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}") + comfy_pbar.update(1) + lora_weights[key] = (U, Vh) + del mat, U, S, Vh + + # make state dict for LoRA + lora_sd = {} + for key, (up_weight, down_weight) in lora_weights.items(): + lora_name = key.replace(".weight", "").replace(".", "_") + lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + lora_name + lora_sd[lora_name + ".lora_up.weight"] = up_weight + lora_sd[lora_name + ".lora_down.weight"] = down_weight + lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0]) # same as rank + + # minimum metadata + net_kwargs = {} + metadata = { + "ss_v2": str(False), + "ss_base_model_version": flux_utils.MODEL_VERSION_FLUX_V1, + "ss_network_module": "networks.lora_flux", + "ss_network_dim": str(dim), + "ss_network_alpha": str(float(dim)), + "ss_network_args": json.dumps(net_kwargs), + } + + if not no_metadata: + title = os.path.splitext(os.path.basename(save_to))[0] + sai_metadata = sai_model_spec.build_metadata(lora_sd, False, False, False, True, False, time.time(), title, flux="dev") + metadata.update(sai_metadata) + + save_to_file(save_to, lora_sd, metadata, save_dtype) + + logger.info(f"LoRA weights saved to {save_to}") + return save_to + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat", + ) + parser.add_argument( + "--model_org", + type=str, + default=None, + required=True, + help="Original model: safetensors file / 元モデル、safetensors", + ) + parser.add_argument( + "--model_tuned", + type=str, + default=None, + required=True, + help="Tuned model, LoRA is difference of `original to tuned`: safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors", + ) + parser.add_argument( + "--mem_eff_safe_open", + action="store_true", + help="use memory efficient safe_open. This is an experimental feature, use only when memory is not enough." + " / メモリ効率の良いsafe_openを使用する。実装は実験的なものなので、メモリが足りない場合のみ使用してください。", + ) + parser.add_argument( + "--save_to", + type=str, + default=None, + required=True, + help="destination file name: safetensors file / 保存先のファイル名、safetensors", + ) + parser.add_argument( + "--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)" + ) + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) + parser.add_argument( + "--clamp_quantile", + type=float, + default=0.99, + help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99", + ) + # parser.add_argument( + # "--min_diff", + # type=float, + # default=0.01, + # help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /" + # + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01", + # ) + parser.add_argument( + "--no_metadata", + action="store_true", + help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + svd(**vars(args)) \ No newline at end of file diff --git a/flux_train_comfy.py b/flux_train_comfy.py new file mode 100644 index 0000000000000000000000000000000000000000..6251ccc473d40f762c657684ffb3a7f17980a352 --- /dev/null +++ b/flux_train_comfy.py @@ -0,0 +1,806 @@ +# training with captions + +# Swap blocks between CPU and GPU: +# This implementation is inspired by and based on the work of 2kpr. +# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading. +# The original idea has been adapted and extended to fit the current project's needs. + +# Key features: +# - CPU offloading during forward and backward passes +# - Use of fused optimizer and grad_hook for efficient gradient processing +# - Per-block fused optimizer instances + +import argparse +import copy +import math +import os +from multiprocessing import Value +import toml + +from tqdm import tqdm + +import torch +from .library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from .library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux +from .library.sd3_train_utils import FlowMatchEulerDiscreteScheduler + +from .library import train_util as train_util + +from .library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from .library import config_util as config_util + +from .library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from .library.custom_train_functions import apply_masked_loss, add_custom_train_arguments + + +class FluxTrainer: + def __init__(self): + self.sample_prompts_te_outputs = None + + def sample_images(self, epoch, global_step, validation_settings): + image_tensors = flux_train_utils.sample_images( + self.accelerator, self.args, epoch, global_step, self.unet, self.vae, self.text_encoder, self.sample_prompts_te_outputs, validation_settings) + return image_tensors + + def init_train(self, args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + # sdxl_train_util.verify_sdxl_training_args(args) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + # temporary: backward compatibility for deprecated options. remove in the future + if not args.skip_cache_check: + args.skip_cache_check = args.skip_latents_validity_check + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cpu_offload_checkpointing and not args.gradient_checkpointing: + logger.warning( + "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります" + ) + args.gradient_checkpointing = True + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, ( + "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + ) + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + if args.cache_latents: + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # Prepare the dataset + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 + + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + if args.debug_dataset: + if args.cache_text_encoder_outputs: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( + strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False + ) + ) + t5xxl_max_token_length = ( + args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512) + ) + strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) + + train_dataset_group.set_current_strategies() + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # load VAE for caching latents + ae = None + if cache_latents: + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + ae.to(accelerator.device, dtype=weight_dtype) + ae.requires_grad_(False) + ae.eval() + + train_dataset_group.new_cache_latents(ae, accelerator) + + ae.to("cpu") # if no sampling, vae can be deleted + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # prepare tokenize strategy + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length) + strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy) + + # load clip_l, t5xxl for caching text encoder outputs + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors) + clip_l.eval() + t5xxl.eval() + clip_l.requires_grad_(False) + t5xxl.requires_grad_(False) + + text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # cache text encoder outputs + sample_prompts_te_outputs = None + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad here + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask + ) + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) + + with accelerator.autocast(): + train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator) + + # cache sample prompt's embeddings to free text encoder's memory + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = [] + for line in args.sample_prompts: + line = line.strip() + if len(line) > 0 and line[0] != "#": + prompts.append(line) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from .library.train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = flux_tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + accelerator.wait_for_everyone() + + # now we can delete Text Encoders to free memory + clip_l = None + t5xxl = None + clean_memory_on_device(accelerator.device) + + # load FLUX + _, flux = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + ) + + if args.gradient_checkpointing: + flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing) + + flux.requires_grad_(True) + + # block swap + + # backward compatibility + if args.blocks_to_swap is None: + blocks_to_swap = args.double_blocks_to_swap or 0 + if args.single_blocks_to_swap is not None: + blocks_to_swap += args.single_blocks_to_swap // 2 + if blocks_to_swap > 0: + logger.warning( + "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead." + " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。" + ) + logger.info( + f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}." + ) + args.blocks_to_swap = blocks_to_swap + del blocks_to_swap + + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + # This idea is based on 2kpr's great work. Thank you! + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + flux.enable_block_swap(args.blocks_to_swap, accelerator.device) + + if not cache_latents: + # load VAE here if not cached + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu") + ae.requires_grad_(False) + ae.eval() + ae.to(accelerator.device, dtype=weight_dtype) + + training_models = [] + params_to_optimize = [] + training_models.append(flux) + name_and_params = list(flux.named_parameters()) + # single param group for now + params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate}) + param_names = [[n for n, _ in name_and_params]] + + # calculate number of trainable parameters + n_params = 0 + for group in params_to_optimize: + for p in group["params"]: + n_params += p.numel() + + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + if args.blockwise_fused_optimizers: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters. + # This balances memory usage and management complexity. + + # split params into groups. currently different learning rates are not supported + grouped_params = [] + param_group = {} + for group in params_to_optimize: + named_parameters = list(flux.named_parameters()) + assert len(named_parameters) == len(group["params"]), "number of parameters does not match" + for p, np in zip(group["params"], named_parameters): + # determine target layer and block index for each parameter + block_type = "other" # double, single or other + if np[0].startswith("double_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "double" + elif np[0].startswith("single_blocks"): + block_index = int(np[0].split(".")[1]) + block_type = "single" + else: + block_index = -1 + + param_group_key = (block_type, block_index) + if param_group_key not in param_group: + param_group[param_group_key] = [] + param_group[param_group_key].append(p) + + block_types_and_indices = [] + for param_group_key, param_group in param_group.items(): + block_types_and_indices.append(param_group_key) + grouped_params.append({"params": param_group, "lr": args.learning_rate}) + + num_params = 0 + for p in param_group: + num_params += p.numel() + accelerator.print(f"block {param_group_key}: {num_params} parameters") + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers") + + if train_util.is_schedulefree_optimizer(optimizers[0], args): + raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers") + self.optimizer_train_fn = lambda: None # dummy function + self.optimizer_eval_fn = lambda: None # dummy function + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + self.optimizer_train_fn, self.optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + if args.blockwise_fused_optimizers: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + flux.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) # TODO check works with fp16 or not + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + flux.to(weight_dtype) + if clip_l is not None: + clip_l.to(weight_dtype) + t5xxl.to(weight_dtype) + + # if we don't cache text encoder outputs, move them to device + if not args.cache_text_encoder_outputs: + clip_l.to(accelerator.device) + t5xxl.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + + if args.deepspeed: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=flux) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + # accelerator does some magic + # if we doesn't swap blocks, we can move the model to device + flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) + if self.is_swapping_blocks: + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + from .library import adafactor_fused + + adafactor_fused.patch_adafactor_fused(optimizer) + + for param_group, param_name_group in zip(optimizer.param_groups, param_names): + for parameter, param_name in zip(param_group["params"], param_name_group): + if parameter.requires_grad: + + def create_grad_hook(p_name, p_group): + def grad_hook(tensor: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, p_group) + tensor.grad = None + + return grad_hook + + parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group)) + + elif args.blockwise_fused_optimizers: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def grad_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(grad_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training") + accelerator.print(f" num examples: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch: {len(train_dataloader)}") + accelerator.print(f" num epochs: {num_train_epochs}") + accelerator.print( + f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + self.global_step = 0 + + noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if self.is_swapping_blocks: + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() + + # For --sample_at_first + #flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + + self.loss_recorder = train_util.LossRecorder() + epoch = 0 # avoid error when max_train_steps is 0 + + self.tokens_and_masks = tokens_and_masks + self.num_train_epochs = num_train_epochs + self.current_epoch = current_epoch + self.args = args + self.accelerator = accelerator + self.unet = flux + self.vae = ae + self.text_encoder = [clip_l, t5xxl] + self.save_dtype = save_dtype + + def training_loop(break_at_steps, epoch): + global optimizer_hooked_count + steps_done = 0 + #accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + progress_bar.set_description(f"Epoch {epoch + 1}/{num_train_epochs} - steps") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = self.global_step + + if args.blockwise_fused_optimizers: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + + # call model + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + if args.bypass_flux_guidance: + flux_utils.bypass_flux_guidance(flux) + + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = flux( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + if args.bypass_flux_guidance: + flux_utils.restore_flux_guidance(flux) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + self.global_step += 1 + + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if len(accelerator.trackers) > 0: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) + + accelerator.log(logs, step=self.global_step) + + self.loss_recorder.add(epoch=epoch, step=step, loss=current_loss, global_step=self.global_step) + avr_loss: float = self.loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if self.global_step >= break_at_steps: + break + steps_done += 1 + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": self.loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + return steps_done + + return training_loop + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) # TODO split this + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_custom_train_arguments(parser) # TODO remove this from here + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) + + parser.add_argument( + "--mem_eff_save", + action="store_true", + help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う", + ) + + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます", + ) + parser.add_argument( + "--blockwise_fused_optimizers", + action="store_true", + help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする", + ) + parser.add_argument( + "--skip_latents_validity_check", + action="store_true", + help="skip latents validity check / latentsの正当性チェックをスキップする", + ) + + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", + ) + return parser diff --git a/flux_train_network_comfy.py b/flux_train_network_comfy.py new file mode 100644 index 0000000000000000000000000000000000000000..8cbbd9edccd3dcec4bda5bbae0fb4dd82e90816e --- /dev/null +++ b/flux_train_network_comfy.py @@ -0,0 +1,500 @@ +import torch +import copy +import math +from typing import Any, Dict, List, Optional, Tuple, Union +import argparse +from .library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util +from .train_network import NetworkTrainer, clean_memory_on_device, setup_parser + +from accelerate import Accelerator + + +import logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +class FluxNetworkTrainer(NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + self.is_schnell: Optional[bool] = None + self.is_swapping_blocks: bool = False + + def assert_extra_args(self, args, train_dataset_group): + super().assert_extra_args(args, train_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # prepare CLIP-L/T5XXL training flags + self.train_clip_l = not args.network_train_unet_only + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + + if args.max_token_length is not None: + logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + + def load_target_model(self, args, weight_dtype, accelerator): + # currently offload to cpu for some models + + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future + self.is_schnell, model = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + ) + if args.fp8_base: + # check dtype of model + if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") + elif model.dtype == torch.float8_e4m3fn or model.dtype == torch.float8_e5m2: + logger.info(f"Loaded {model.dtype} FLUX model") + + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + model.enable_block_swap(args.blocks_to_swap, accelerator.device) + + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + clip_l.eval() + + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + + # loading t5xxl to cpu takes a long time, so we should load to gpu in future + t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + + return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + + def get_tokenize_strategy(self, args): + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + + if args.t5xxl_max_token_length is None: + if is_schnell: + t5xxl_max_token_length = 256 + else: + t5xxl_max_token_length = 512 + else: + t5xxl_max_token_length = args.t5xxl_max_token_length + + logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}") + return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy): + return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + if args.cache_text_encoder_outputs: + if self.train_clip_l and not self.train_t5xxl: + return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached + else: + return None # no text encoders are needed for encoding because both are cached + else: + return text_encoders # both CLIP-L and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_clip_l, self.train_t5xxl] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_flux.FluxTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_clip_l or self.train_t5xxl, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # reduce memory consumption + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device) + + if text_encoders[1].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[1].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = [] + for line in args.sample_prompts: + line = line.strip() + if len(line) > 0 and line[0] != "#": + prompts.append(line) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from .library.train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # move back to cpu + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[1].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoder + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device) + + def sample_images(self, epoch, global_step, validation_settings): + text_encoders = self.get_models_for_text_encoding(self.args, self.accelerator, self.text_encoder) + + image_tensors = flux_train_utils.sample_images( + self.accelerator, self.args, epoch, global_step, self.unet, self.vae, text_encoders, self.sample_prompts_te_outputs, validation_settings) + clean_memory_on_device(self.accelerator.device) + return image_tensors + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images) + + def shift_scale_latents(self, args, latents): + return latents + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: flux_models.Flux, + network, + weight_dtype, + train_unet, + ): + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance + # ensure guidance_scale in args is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t is not None and t.dtype.is_floating_point: + t.requires_grad_(True) + img_ids.requires_grad_(True) + guidance_vec.requires_grad_(True) + + # Predict the noise residual + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): + # normal forward + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=img, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + """ + else: + # split forward to reduce memory usage + assert network.train_blocks == "single", "train_blocks must be single for split mode" + with accelerator.autocast(): + # move flux lower to cpu, and then move flux upper to gpu + unet.to("cpu") + clean_memory_on_device(accelerator.device) + self.flux_upper.to(accelerator.device) + + # upper model does not require grad + with torch.no_grad(): + intermediate_img, intermediate_txt, vec, pe = self.flux_upper( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # move flux upper back to cpu, and then move flux lower to gpu + self.flux_upper.to("cpu") + clean_memory_on_device(accelerator.device) + unet.to(accelerator.device) + + # lower model requires grad + intermediate_img.requires_grad_(True) + intermediate_txt.requires_grad_(True) + vec.requires_grad_(True) + pe.requires_grad_(True) + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + """ + + return model_pred + + if args.bypass_flux_guidance: + flux_utils.bypass_flux_guidance(unet) + + model_pred = call_dit( + img=packed_noisy_model_input, + img_ids=img_ids, + t5_out=t5_out, + txt_ids=txt_ids, + l_pooled=l_pooled, + timesteps=timesteps, + guidance_vec=guidance_vec, + t5_attn_mask=t5_attn_mask, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + if args.bypass_flux_guidance: #for flex + flux_utils.restore_flux_guidance(unet) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + unet.prepare_block_swap_before_forward() + with torch.no_grad(): + model_pred_prior = call_dit( + img=packed_noisy_model_input[diff_output_pr_indices], + img_ids=img_ids[diff_output_pr_indices], + t5_out=t5_out[diff_output_pr_indices], + txt_ids=txt_ids[diff_output_pr_indices], + l_pooled=l_pooled[diff_output_pr_indices], + timesteps=timesteps[diff_output_pr_indices], + guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None, + t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + + model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width) + model_pred_prior, _ = flux_train_utils.apply_model_prediction_type( + args, + model_pred_prior, + noisy_model_input[diff_output_pr_indices], + sigmas[diff_output_pr_indices] if sigmas is not None else None, + ) + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + + return model_pred, target, timesteps, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") + + def update_metadata(self, metadata, args): + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + metadata["ss_guidance_scale"] = args.guidance_scale + metadata["ss_timestep_sampling"] = args.timestep_sampling + metadata["ss_sigmoid_scale"] = args.sigmoid_scale + metadata["ss_model_prediction_type"] = args.model_prediction_type + metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + if index == 0: # CLIP-L + return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) + else: # T5XXL + text_encoder.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + if index == 0: # CLIP-L + logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + else: # T5XXL + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"T5XXL already prepared for fp8") + else: + logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + prepare_fp8(text_encoder, weight_dtype) + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + flux: flux_models.Flux = unet + flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(flux).prepare_block_swap_before_forward() + + return flux + + +def setup_parser() -> argparse.ArgumentParser: + parser = setup_parser() + train_util.add_dit_training_arguments(parser) + flux_train_utils.add_flux_train_arguments(parser) diff --git a/hf_token.json b/hf_token.json new file mode 100644 index 0000000000000000000000000000000000000000..ab7d5c76a4ead5351ee7e95329a48677a3501605 --- /dev/null +++ b/hf_token.json @@ -0,0 +1,3 @@ +{ + "hf_token": "your_token_here" +} \ No newline at end of file diff --git a/icon.png b/icon.png new file mode 100644 index 0000000000000000000000000000000000000000..c763bed52203d1c9f56607c01dc28b3a7f1eee2d Binary files /dev/null and b/icon.png differ diff --git a/install.js b/install.js new file mode 100644 index 0000000000000000000000000000000000000000..d3d911125808d634cf5a8794e93f3f6d7791c46b --- /dev/null +++ b/install.js @@ -0,0 +1,96 @@ +module.exports = { + run: [ + { + method: "shell.run", + params: { + venv: "env", + message: [ + "git config --global --add safe.directory '*'", + "git clone -b sd3 https://github.com/kohya-ss/sd-scripts" + ] + } + }, + { + method: "shell.run", + params: { + path: "sd-scripts", + venv: "../env", + message: [ + "uv pip install -r requirements.txt", + ] + } + }, + { + method: "shell.run", + params: { + venv: "env", + message: [ + "pip uninstall -y diffusers[torch] torch torchaudio torchvision", + "uv pip install -r requirements.txt", + ] + } + }, + { + method: "script.start", + params: { + uri: "torch.js", + params: { + venv: "env", + // xformers: true // uncomment this line if your project requires xformers + } + } + }, + { + method: "fs.link", + params: { + drive: { + vae: "models/vae", + clip: "models/clip", + unet: "models/unet", + loras: "outputs", + }, + peers: [ + "https://github.com/pinokiofactory/stable-diffusion-webui-forge.git", + "https://github.com/pinokiofactory/comfy.git", + "https://github.com/cocktailpeanutlabs/comfyui.git", + "https://github.com/cocktailpeanutlabs/fooocus.git", + "https://github.com/cocktailpeanutlabs/automatic1111.git", + ] + } + }, +// { +// method: "fs.download", +// params: { +// uri: [ +// "https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/clip_l.safetensors?download=true", +// "https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/t5xxl_fp16.safetensors?download=true", +// ], +// dir: "models/clip" +// } +// }, +// { +// method: "fs.download", +// params: { +// uri: [ +// "https://huggingface.co/cocktailpeanut/xulf-dev/resolve/main/ae.sft?download=true", +// ], +// dir: "models/vae" +// } +// }, +// { +// method: "fs.download", +// params: { +// uri: [ +// "https://huggingface.co/cocktailpeanut/xulf-dev/resolve/main/flux1-dev.sft?download=true", +// ], +// dir: "models/unet" +// } +// }, + { + method: "fs.link", + params: { + venv: "env" + } + } + ] +} diff --git a/library/__init__.py b/library/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/library/__pycache__/__init__.cpython-310.pyc b/library/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4434c91f0781a8c54a5ec928892548bef8dd3555 Binary files /dev/null and b/library/__pycache__/__init__.cpython-310.pyc differ diff --git a/library/__pycache__/config_util.cpython-310.pyc b/library/__pycache__/config_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7af25c53b8bddd90ce54163c1e3d3b56354a5f4d Binary files /dev/null and b/library/__pycache__/config_util.cpython-310.pyc differ diff --git a/library/__pycache__/custom_offloading_utils.cpython-310.pyc b/library/__pycache__/custom_offloading_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d5415a1e497b01444439652630b5f4b7a49c8cd Binary files /dev/null and b/library/__pycache__/custom_offloading_utils.cpython-310.pyc differ diff --git a/library/__pycache__/custom_train_functions.cpython-310.pyc b/library/__pycache__/custom_train_functions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ec9879a42194bd300ddca5f5b5b6360272b587b Binary files /dev/null and b/library/__pycache__/custom_train_functions.cpython-310.pyc differ diff --git a/library/__pycache__/deepspeed_utils.cpython-310.pyc b/library/__pycache__/deepspeed_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ba3ef9a24c69ab1ba0731314ea8da98d8fa75d0 Binary files /dev/null and b/library/__pycache__/deepspeed_utils.cpython-310.pyc differ diff --git a/library/__pycache__/device_utils.cpython-310.pyc b/library/__pycache__/device_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..972b8250302570e149242ea63616cdd3efb558bc Binary files /dev/null and b/library/__pycache__/device_utils.cpython-310.pyc differ diff --git a/library/__pycache__/flux_models.cpython-310.pyc b/library/__pycache__/flux_models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56ce1d4ea96448680e9c47a306efe1b420a01e7e Binary files /dev/null and b/library/__pycache__/flux_models.cpython-310.pyc differ diff --git a/library/__pycache__/flux_train_utils.cpython-310.pyc b/library/__pycache__/flux_train_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8232f6bb95aef764011a49a6b3c37850d32ea8da Binary files /dev/null and b/library/__pycache__/flux_train_utils.cpython-310.pyc differ diff --git a/library/__pycache__/flux_utils.cpython-310.pyc b/library/__pycache__/flux_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf8997d908fcc5aa7dd2c97cf15daf72f2a8dcbd Binary files /dev/null and b/library/__pycache__/flux_utils.cpython-310.pyc differ diff --git a/library/__pycache__/huggingface_util.cpython-310.pyc b/library/__pycache__/huggingface_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf0c380c4a08fd6a19cc64d05794479e4e93be51 Binary files /dev/null and b/library/__pycache__/huggingface_util.cpython-310.pyc differ diff --git a/library/__pycache__/model_util.cpython-310.pyc b/library/__pycache__/model_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..614429b5176628baee564a8fe07a6c11487b0f04 Binary files /dev/null and b/library/__pycache__/model_util.cpython-310.pyc differ diff --git a/library/__pycache__/original_unet.cpython-310.pyc b/library/__pycache__/original_unet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40324b150ca486e128b08b1006fd117b15a302e1 Binary files /dev/null and b/library/__pycache__/original_unet.cpython-310.pyc differ diff --git a/library/__pycache__/sai_model_spec.cpython-310.pyc b/library/__pycache__/sai_model_spec.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea3d3d1c582b76697080f6314c2d9ff5850fe4b8 Binary files /dev/null and b/library/__pycache__/sai_model_spec.cpython-310.pyc differ diff --git a/library/__pycache__/sd3_models.cpython-310.pyc b/library/__pycache__/sd3_models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80484c7229bd39c4f072bc6fa04b3288a34943ca Binary files /dev/null and b/library/__pycache__/sd3_models.cpython-310.pyc differ diff --git a/library/__pycache__/sd3_utils.cpython-310.pyc b/library/__pycache__/sd3_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2007e8adf9309e66d0b16b84c684099618b99612 Binary files /dev/null and b/library/__pycache__/sd3_utils.cpython-310.pyc differ diff --git a/library/__pycache__/strategy_base.cpython-310.pyc b/library/__pycache__/strategy_base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b000e3252e6ae0f8a06a8a86604be57399f799fe Binary files /dev/null and b/library/__pycache__/strategy_base.cpython-310.pyc differ diff --git a/library/__pycache__/strategy_sd.cpython-310.pyc b/library/__pycache__/strategy_sd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67f8c7430846716a0c47828f8d1a155c6772f68f Binary files /dev/null and b/library/__pycache__/strategy_sd.cpython-310.pyc differ diff --git a/library/__pycache__/train_util.cpython-310.pyc b/library/__pycache__/train_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..156c9a9a8e2cdac3e52c8ad85df43259f0a285f4 --- /dev/null +++ b/library/__pycache__/train_util.cpython-310.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa71a44895d0a006e41ba9fadbd0177a9ad5499cc89aeb2266aa1c7a9597e82e +size 164434 diff --git a/library/__pycache__/utils.cpython-310.pyc b/library/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62837bec9f40aacaacc0075a2f70745ce7c84049 Binary files /dev/null and b/library/__pycache__/utils.cpython-310.pyc differ diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py new file mode 100644 index 0000000000000000000000000000000000000000..b5afa236bc8d76779d5f90e6db9d52cf2261bdb3 --- /dev/null +++ b/library/adafactor_fused.py @@ -0,0 +1,138 @@ +import math +import torch +from transformers import Adafactor + +# stochastic rounding for bfloat16 +# The implementation was provided by 2kpr. Thank you very much! + +def copy_stochastic_(target: torch.Tensor, source: torch.Tensor): + """ + copies source into target using stochastic rounding + + Args: + target: the target tensor with dtype=bfloat16 + source: the target tensor with dtype=float32 + """ + # create a random 16 bit integer + result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16)) + + # add the random number to the lower 16 bit of the mantissa + result.add_(source.view(dtype=torch.int32)) + + # mask off the lower 16 bit of the mantissa + result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 + + # copy the higher 16 bit into the target tensor + target.copy_(result.view(dtype=torch.float32)) + + del result + + +@torch.no_grad() +def adafactor_step_param(self, p, group): + if p.grad is None: + return + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = Adafactor._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + state["RMS"] = Adafactor._rms(p_data_fp32) + lr = Adafactor._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + + # Approximation of exponential moving average of square of gradient + update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) + + p_data_fp32.add_(-update) + + # if p.dtype in {torch.float16, torch.bfloat16}: + # p.copy_(p_data_fp32) + + if p.dtype == torch.bfloat16: + copy_stochastic_(p, p_data_fp32) + elif p.dtype == torch.float16: + p.copy_(p_data_fp32) + + +@torch.no_grad() +def adafactor_step(self, closure=None): + """ + Performs a single optimization step + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + adafactor_step_param(self, p, group) + + return loss + + +def patch_adafactor_fused(optimizer: Adafactor): + optimizer.step_param = adafactor_step_param.__get__(optimizer) + optimizer.step = adafactor_step.__get__(optimizer) diff --git a/library/attention_processors.py b/library/attention_processors.py new file mode 100644 index 0000000000000000000000000000000000000000..310c2cb1c63955f8f03296c54fd47c21f1a981c9 --- /dev/null +++ b/library/attention_processors.py @@ -0,0 +1,227 @@ +import math +from typing import Any +from einops import rearrange +import torch +from diffusers.models.attention_processor import Attention + + +# flash attention forwards and backwards + +# https://arxiv.org/abs/2205.14135 + +EPSILON = 1e-6 + + +class FlashAttentionFunction(torch.autograd.function.Function): + @staticmethod + @torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full( + (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device + ) + + scale = q.shape[-1] ** -0.5 + + if mask is None: + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = ( + torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + ) + + if row_mask is not None: + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones( + (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device + ).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if row_mask is not None: + exp_weights.masked_fill_(~row_mask, 0.0) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp( + min=EPSILON + ) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = torch.einsum( + "... i j, ... j d -> ... i d", exp_weights, vc + ) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = ( + exp_row_max_diff * row_sums + + exp_block_row_max_diff * block_row_sums + ) + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_( + (exp_block_row_max_diff / new_row_sums) * exp_values + ) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = ( + torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + ) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones( + (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device + ).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if row_mask is not None: + exp_attn_weights.masked_fill_(~row_mask, 0.0) + + p = exp_attn_weights / lc + + dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc) + dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +class FlashAttnProcessor: + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ) -> Any: + q_bucket_size = 512 + k_bucket_size = 1024 + + h = attn.heads + q = attn.to_q(hidden_states) + + encoder_hidden_states = ( + encoder_hidden_states + if encoder_hidden_states is not None + else hidden_states + ) + encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype) + + if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None: + context_k, context_v = attn.hypernetwork.forward( + hidden_states, encoder_hidden_states + ) + context_k = context_k.to(hidden_states.dtype) + context_v = context_v.to(hidden_states.dtype) + else: + context_k = encoder_hidden_states + context_v = encoder_hidden_states + + k = attn.to_k(context_k) + v = attn.to_v(context_v) + del encoder_hidden_states, hidden_states + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = FlashAttentionFunction.apply( + q, k, v, attention_mask, False, q_bucket_size, k_bucket_size + ) + + out = rearrange(out, "b h n d -> b n (h d)") + + out = attn.to_out[0](out) + out = attn.to_out[1](out) + return out diff --git a/library/config_util.py b/library/config_util.py new file mode 100644 index 0000000000000000000000000000000000000000..458e2fe81ab78b3dcc02bee09a74ededa9895de2 --- /dev/null +++ b/library/config_util.py @@ -0,0 +1,717 @@ +import argparse +from dataclasses import ( + asdict, + dataclass, +) +import functools +import random +from textwrap import dedent, indent +import json +from pathlib import Path + +# from toolz import curry +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import toml +import voluptuous +from voluptuous import ( + Any, + ExactSequence, + MultipleInvalid, + Object, + Required, + Schema, +) + + +from . import train_util +from .train_util import ( + DreamBoothSubset, + FineTuningSubset, + ControlNetSubset, + DreamBoothDataset, + FineTuningDataset, + ControlNetDataset, + DatasetGroup, +) +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def add_config_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル" + ) + + +# TODO: inherit Params class in Subset, Dataset + + +@dataclass +class BaseSubsetParams: + image_dir: Optional[str] = None + num_repeats: int = 1 + shuffle_caption: bool = False + caption_separator: str = (",",) + keep_tokens: int = 0 + keep_tokens_separator: str = (None,) + secondary_separator: Optional[str] = None + enable_wildcard: bool = False + color_aug: bool = False + flip_aug: bool = False + face_crop_aug_range: Optional[Tuple[float, float]] = None + random_crop: bool = False + caption_prefix: Optional[str] = None + caption_suffix: Optional[str] = None + caption_dropout_rate: float = 0.0 + caption_dropout_every_n_epochs: int = 0 + caption_tag_dropout_rate: float = 0.0 + token_warmup_min: int = 1 + token_warmup_step: float = 0 + custom_attributes: Optional[Dict[str, Any]] = None + + +@dataclass +class DreamBoothSubsetParams(BaseSubsetParams): + is_reg: bool = False + class_tokens: Optional[str] = None + caption_extension: str = ".caption" + cache_info: bool = False + alpha_mask: bool = False + + +@dataclass +class FineTuningSubsetParams(BaseSubsetParams): + metadata_file: Optional[str] = None + alpha_mask: bool = False + + +@dataclass +class ControlNetSubsetParams(BaseSubsetParams): + conditioning_data_dir: str = None + caption_extension: str = ".caption" + cache_info: bool = False + + +@dataclass +class BaseDatasetParams: + resolution: Optional[Tuple[int, int]] = None + network_multiplier: float = 1.0 + debug_dataset: bool = False + + +@dataclass +class DreamBoothDatasetParams(BaseDatasetParams): + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + prior_loss_weight: float = 1.0 + + +@dataclass +class FineTuningDatasetParams(BaseDatasetParams): + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + + +@dataclass +class ControlNetDatasetParams(BaseDatasetParams): + batch_size: int = 1 + enable_bucket: bool = False + min_bucket_reso: int = 256 + max_bucket_reso: int = 1024 + bucket_reso_steps: int = 64 + bucket_no_upscale: bool = False + + +@dataclass +class SubsetBlueprint: + params: Union[DreamBoothSubsetParams, FineTuningSubsetParams] + + +@dataclass +class DatasetBlueprint: + is_dreambooth: bool + is_controlnet: bool + params: Union[DreamBoothDatasetParams, FineTuningDatasetParams] + subsets: Sequence[SubsetBlueprint] + + +@dataclass +class DatasetGroupBlueprint: + datasets: Sequence[DatasetBlueprint] + + +@dataclass +class Blueprint: + dataset_group: DatasetGroupBlueprint + + +class ConfigSanitizer: + # @curry + @staticmethod + def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple: + Schema(ExactSequence([klass, klass]))(value) + return tuple(value) + + # @curry + @staticmethod + def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple: + Schema(Any(klass, ExactSequence([klass, klass])))(value) + try: + Schema(klass)(value) + return (value, value) + except: + return ConfigSanitizer.__validate_and_convert_twodim(klass, value) + + # subset schema + SUBSET_ASCENDABLE_SCHEMA = { + "color_aug": bool, + "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float), + "flip_aug": bool, + "num_repeats": int, + "random_crop": bool, + "shuffle_caption": bool, + "keep_tokens": int, + "keep_tokens_separator": str, + "secondary_separator": str, + "caption_separator": str, + "enable_wildcard": bool, + "token_warmup_min": int, + "token_warmup_step": Any(float, int), + "caption_prefix": str, + "caption_suffix": str, + "custom_attributes": dict, + } + # DO means DropOut + DO_SUBSET_ASCENDABLE_SCHEMA = { + "caption_dropout_every_n_epochs": int, + "caption_dropout_rate": Any(float, int), + "caption_tag_dropout_rate": Any(float, int), + } + # DB means DreamBooth + DB_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + "class_tokens": str, + "cache_info": bool, + } + DB_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + "is_reg": bool, + "alpha_mask": bool, + } + # FT means FineTuning + FT_SUBSET_DISTINCT_SCHEMA = { + Required("metadata_file"): str, + "image_dir": str, + "alpha_mask": bool, + } + CN_SUBSET_ASCENDABLE_SCHEMA = { + "caption_extension": str, + "cache_info": bool, + } + CN_SUBSET_DISTINCT_SCHEMA = { + Required("image_dir"): str, + Required("conditioning_data_dir"): str, + } + + # datasets schema + DATASET_ASCENDABLE_SCHEMA = { + "batch_size": int, + "bucket_no_upscale": bool, + "bucket_reso_steps": int, + "enable_bucket": bool, + "max_bucket_reso": int, + "min_bucket_reso": int, + "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), + "network_multiplier": float, + } + + # options handled by argparse but not handled by user config + ARGPARSE_SPECIFIC_SCHEMA = { + "debug_dataset": bool, + "max_token_length": Any(None, int), + "prior_loss_weight": Any(float, int), + } + # for handling default None value of argparse + ARGPARSE_NULLABLE_OPTNAMES = [ + "face_crop_aug_range", + "resolution", + ] + # prepare map because option name may differ among argparse and user config + ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { + "train_batch_size": "batch_size", + "dataset_repeats": "num_repeats", + } + + def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None: + assert support_dreambooth or support_finetuning or support_controlnet, ( + "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more." + + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。" + ) + + self.db_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_DISTINCT_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.ft_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.FT_SUBSET_DISTINCT_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.cn_subset_schema = self.__merge_dict( + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_DISTINCT_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.db_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.db_subset_schema]}, + ) + + self.ft_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.ft_subset_schema]}, + ) + + self.cn_dataset_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.CN_SUBSET_ASCENDABLE_SCHEMA, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + {"subsets": [self.cn_subset_schema]}, + ) + + if support_dreambooth and support_finetuning: + + def validate_flex_dataset(dataset_config: dict): + subsets_config = dataset_config.get("subsets", []) + + if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]): + return Schema(self.cn_dataset_schema)(dataset_config) + # check dataset meets FT style + # NOTE: all FT subsets should have "metadata_file" + elif all(["metadata_file" in subset for subset in subsets_config]): + return Schema(self.ft_dataset_schema)(dataset_config) + # check dataset meets DB style + # NOTE: all DB subsets should have no "metadata_file" + elif all(["metadata_file" not in subset for subset in subsets_config]): + return Schema(self.db_dataset_schema)(dataset_config) + else: + raise voluptuous.Invalid( + "DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。" + ) + + self.dataset_schema = validate_flex_dataset + elif support_dreambooth: + if support_controlnet: + self.dataset_schema = self.cn_dataset_schema + else: + self.dataset_schema = self.db_dataset_schema + elif support_finetuning: + self.dataset_schema = self.ft_dataset_schema + elif support_controlnet: + self.dataset_schema = self.cn_dataset_schema + + self.general_schema = self.__merge_dict( + self.DATASET_ASCENDABLE_SCHEMA, + self.SUBSET_ASCENDABLE_SCHEMA, + self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {}, + self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {}, + self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {}, + ) + + self.user_config_validator = Schema( + { + "general": self.general_schema, + "datasets": [self.dataset_schema], + } + ) + + self.argparse_schema = self.__merge_dict( + self.general_schema, + self.ARGPARSE_SPECIFIC_SCHEMA, + {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES}, + {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()}, + ) + + self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA) + + def sanitize_user_config(self, user_config: dict) -> dict: + try: + return self.user_config_validator(user_config) + except MultipleInvalid: + # TODO: エラー発生時のメッセージをわかりやすくする + logger.error("Invalid user config / ユーザ設定の形式が正しくないようです") + raise + + # NOTE: In nature, argument parser result is not needed to be sanitize + # However this will help us to detect program bug + def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace: + try: + return self.argparse_config_validator(argparse_namespace) + except MultipleInvalid: + # XXX: this should be a bug + logger.error( + "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。" + ) + raise + + # NOTE: value would be overwritten by latter dict if there is already the same key + @staticmethod + def __merge_dict(*dict_list: dict) -> dict: + merged = {} + for schema in dict_list: + # merged |= schema + for k, v in schema.items(): + merged[k] = v + return merged + + +class BlueprintGenerator: + BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {} + + def __init__(self, sanitizer: ConfigSanitizer): + self.sanitizer = sanitizer + + # runtime_params is for parameters which is only configurable on runtime, such as tokenizer + def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint: + sanitized_user_config = self.sanitizer.sanitize_user_config(user_config) + sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace) + + # convert argparse namespace to dict like config + # NOTE: it is ok to have extra entries in dict + optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME + argparse_config = { + optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items() + } + + general_config = sanitized_user_config.get("general", {}) + + dataset_blueprints = [] + for dataset_config in sanitized_user_config.get("datasets", []): + # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets + subsets = dataset_config.get("subsets", []) + is_dreambooth = all(["metadata_file" not in subset for subset in subsets]) + is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets]) + if is_controlnet: + subset_params_klass = ControlNetSubsetParams + dataset_params_klass = ControlNetDatasetParams + elif is_dreambooth: + subset_params_klass = DreamBoothSubsetParams + dataset_params_klass = DreamBoothDatasetParams + else: + subset_params_klass = FineTuningSubsetParams + dataset_params_klass = FineTuningDatasetParams + + subset_blueprints = [] + for subset_config in subsets: + params = self.generate_params_by_fallbacks( + subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params] + ) + subset_blueprints.append(SubsetBlueprint(params)) + + params = self.generate_params_by_fallbacks( + dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params] + ) + dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints)) + + dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints) + + return Blueprint(dataset_group_blueprint) + + @staticmethod + def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]): + name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME + search_value = BlueprintGenerator.search_value + default_params = asdict(param_klass()) + param_names = default_params.keys() + + params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names} + + return param_klass(**params) + + @staticmethod + def search_value(key: str, fallbacks: Sequence[dict], default_value=None): + for cand in fallbacks: + value = cand.get(key) + if value is not None: + return value + + return default_value + + +def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint): + datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = [] + + for dataset_blueprint in dataset_group_blueprint.datasets: + if dataset_blueprint.is_controlnet: + subset_klass = ControlNetSubset + dataset_klass = ControlNetDataset + elif dataset_blueprint.is_dreambooth: + subset_klass = DreamBoothSubset + dataset_klass = DreamBoothDataset + else: + subset_klass = FineTuningSubset + dataset_klass = FineTuningDataset + + subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets] + dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params)) + datasets.append(dataset) + + # print info + info = "" + for i, dataset in enumerate(datasets): + is_dreambooth = isinstance(dataset, DreamBoothDataset) + is_controlnet = isinstance(dataset, ControlNetDataset) + info += dedent( + f"""\ + [Dataset {i}] + batch_size: {dataset.batch_size} + resolution: {(dataset.width, dataset.height)} + enable_bucket: {dataset.enable_bucket} + network_multiplier: {dataset.network_multiplier} + """ + ) + + if dataset.enable_bucket: + info += indent( + dedent( + f"""\ + min_bucket_reso: {dataset.min_bucket_reso} + max_bucket_reso: {dataset.max_bucket_reso} + bucket_reso_steps: {dataset.bucket_reso_steps} + bucket_no_upscale: {dataset.bucket_no_upscale} + \n""" + ), + " ", + ) + else: + info += "\n" + + for j, subset in enumerate(dataset.subsets): + info += indent( + dedent( + f"""\ + [Subset {j} of Dataset {i}] + image_dir: "{subset.image_dir}" + image_count: {subset.img_count} + num_repeats: {subset.num_repeats} + shuffle_caption: {subset.shuffle_caption} + keep_tokens: {subset.keep_tokens} + keep_tokens_separator: {subset.keep_tokens_separator} + caption_separator: {subset.caption_separator} + secondary_separator: {subset.secondary_separator} + enable_wildcard: {subset.enable_wildcard} + caption_dropout_rate: {subset.caption_dropout_rate} + caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs} + caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} + caption_prefix: {subset.caption_prefix} + caption_suffix: {subset.caption_suffix} + color_aug: {subset.color_aug} + flip_aug: {subset.flip_aug} + face_crop_aug_range: {subset.face_crop_aug_range} + random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min} + token_warmup_step: {subset.token_warmup_step} + alpha_mask: {subset.alpha_mask} + custom_attributes: {subset.custom_attributes} + """ + ), + " ", + ) + + if is_dreambooth: + info += indent( + dedent( + f"""\ + is_reg: {subset.is_reg} + class_tokens: {subset.class_tokens} + caption_extension: {subset.caption_extension} + \n""" + ), + " ", + ) + elif not is_controlnet: + info += indent( + dedent( + f"""\ + metadata_file: {subset.metadata_file} + \n""" + ), + " ", + ) + + logger.info(f"{info}") + + # make buckets first because it determines the length of dataset + # and set the same seed for all datasets + seed = random.randint(0, 2**31) # actual seed is seed + epoch_no + for i, dataset in enumerate(datasets): + logger.info(f"[Dataset {i}]") + dataset.make_buckets() + dataset.set_seed(seed) + + return DatasetGroup(datasets) + + +def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None): + def extract_dreambooth_params(name: str) -> Tuple[int, str]: + tokens = name.split("_") + try: + n_repeats = int(tokens[0]) + except ValueError as e: + logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}") + return 0, "" + caption_by_folder = "_".join(tokens[1:]) + return n_repeats, caption_by_folder + + def generate(base_dir: Optional[str], is_reg: bool): + if base_dir is None: + return [] + + base_dir: Path = Path(base_dir) + if not base_dir.is_dir(): + return [] + + subsets_config = [] + for subdir in base_dir.iterdir(): + if not subdir.is_dir(): + continue + + num_repeats, class_tokens = extract_dreambooth_params(subdir.name) + if num_repeats < 1: + continue + + subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens} + subsets_config.append(subset_config) + + return subsets_config + + subsets_config = [] + subsets_config += generate(train_data_dir, False) + subsets_config += generate(reg_data_dir, True) + + return subsets_config + + +def generate_controlnet_subsets_config_by_subdirs( + train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt" +): + def generate(base_dir: Optional[str]): + if base_dir is None: + return [] + + base_dir: Path = Path(base_dir) + if not base_dir.is_dir(): + return [] + + subsets_config = [] + subset_config = { + "image_dir": train_data_dir, + "conditioning_data_dir": conditioning_data_dir, + "caption_extension": caption_extension, + "num_repeats": 1, + } + subsets_config.append(subset_config) + + return subsets_config + + subsets_config = [] + subsets_config += generate(train_data_dir) + + return subsets_config + + +def load_user_config(file: str) -> dict: + file_path: Path = Path(file) + if not file_path.is_file(): + #raise ValueError(f"file not found / ファイルが見つかりません: {file}") + return toml.loads(file) + + if file_path.name.lower().endswith(".json"): + try: + with open(file, "r") as f: + config = json.load(f) + except Exception: + logger.error( + f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) + raise + elif file_path.name.lower().endswith(".toml"): + try: + config = toml.load(file_path) + except Exception: + logger.error( + f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}" + ) + raise + else: + raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file_path}") + + return config + + +# for config test +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--support_dreambooth", action="store_true") + parser.add_argument("--support_finetuning", action="store_true") + parser.add_argument("--support_controlnet", action="store_true") + parser.add_argument("--support_dropout", action="store_true") + parser.add_argument("dataset_config") + config_args, remain = parser.parse_known_args() + + parser = argparse.ArgumentParser() + train_util.add_dataset_arguments( + parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout + ) + train_util.add_training_arguments(parser, config_args.support_dreambooth) + argparse_namespace = parser.parse_args(remain) + train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning) + + logger.info("[argparse_namespace]") + logger.info(f"{vars(argparse_namespace)}") + + user_config = load_user_config(config_args.dataset_config) + + logger.info("") + logger.info("[user_config]") + logger.info(f"{user_config}") + + sanitizer = ConfigSanitizer( + config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout + ) + sanitized_user_config = sanitizer.sanitize_user_config(user_config) + + logger.info("") + logger.info("[sanitized_user_config]") + logger.info(f"{sanitized_user_config}") + + blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) + + logger.info("") + logger.info("[blueprint]") + logger.info(f"{blueprint}") \ No newline at end of file diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0565ea4cc1c2b114414b5cc09089caf1c4a0bbcc --- /dev/null +++ b/library/custom_offloading_utils.py @@ -0,0 +1,227 @@ +from concurrent.futures import ThreadPoolExecutor +import time +from typing import Optional +import torch +import torch.nn as nn + +from .device_utils import clean_memory_on_device + + +def synchronize_device(device: torch.device): + if device.type == "cuda": + torch.cuda.synchronize() + elif device.type == "xpu": + torch.xpu.synchronize() + elif device.type == "mps": + torch.mps.synchronize() + + +def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + + # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules + # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + # print(module_to_cpu.__class__, module_to_cuda.__class__) + # if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + # weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()} + for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules(): + if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None: + module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None) + if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + else: + if module_to_cuda.weight.data.device.type != device.type: + # print( + # f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device" + # ) + module_to_cuda.weight.data = module_to_cuda.weight.data.to(device) + + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + # cuda to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.record_stream(stream) + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + stream.synchronize() + + # cpu to cuda + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + stream.synchronize() + torch.cuda.current_stream().synchronize() # this prevents the illegal loss value + + +def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): + """ + not tested + """ + assert layer_to_cpu.__class__ == layer_to_cuda.__class__ + + weight_swap_jobs = [] + for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): + if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: + weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) + + # device to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + + synchronize_device() + + # cpu to device + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view + + synchronize_device() + + +def weighs_to_device(layer: nn.Module, device: torch.device): + for module in layer.modules(): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data = module.weight.data.to(device, non_blocking=True) + + +class Offloader: + """ + common offloading class + """ + + def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + self.num_blocks = num_blocks + self.blocks_to_swap = blocks_to_swap + self.device = device + self.debug = debug + + self.thread_pool = ThreadPoolExecutor(max_workers=1) + self.futures = {} + self.cuda_available = device.type == "cuda" + + def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module): + if self.cuda_available: + swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda) + else: + swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda) + + def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda): + def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda): + if self.debug: + start_time = time.perf_counter() + print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}") + + self.swap_weight_devices(block_to_cpu, block_to_cuda) + + if self.debug: + print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s") + return bidx_to_cpu, bidx_to_cuda # , event + + block_to_cpu = blocks[block_idx_to_cpu] + block_to_cuda = blocks[block_idx_to_cuda] + + self.futures[block_idx_to_cuda] = self.thread_pool.submit( + move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda + ) + + def _wait_blocks_move(self, block_idx): + if block_idx not in self.futures: + return + + if self.debug: + print(f"Wait for block {block_idx}") + start_time = time.perf_counter() + + future = self.futures.pop(block_idx) + _, bidx_to_cuda = future.result() + + assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}" + + if self.debug: + print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s") + + +class ModelOffloader(Offloader): + """ + supports forward offloading + """ + + def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False): + super().__init__(num_blocks, blocks_to_swap, device, debug) + + # register backward hooks + self.remove_handles = [] + for i, block in enumerate(blocks): + hook = self.create_backward_hook(blocks, i) + if hook is not None: + handle = block.register_full_backward_hook(hook) + self.remove_handles.append(handle) + + def __del__(self): + for handle in self.remove_handles: + handle.remove() + + def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]: + # -1 for 0-based index + num_blocks_propagated = self.num_blocks - block_index - 1 + swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap + waiting = block_index > 0 and block_index <= self.blocks_to_swap + + if not swapping and not waiting: + return None + + # create hook + block_idx_to_cpu = self.num_blocks - num_blocks_propagated + block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated + block_idx_to_wait = block_index - 1 + + def backward_hook(module, grad_input, grad_output): + if self.debug: + print(f"Backward hook for block {block_index}") + + if swapping: + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) + if waiting: + self._wait_blocks_move(block_idx_to_wait) + return None + + return backward_hook + + def prepare_block_devices_before_forward(self, blocks: list[nn.Module]): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + + if self.debug: + print("Prepare block devices before forward") + + for b in blocks[0 : self.num_blocks - self.blocks_to_swap]: + b.to(self.device) + weighs_to_device(b, self.device) # make sure weights are on device + + for b in blocks[self.num_blocks - self.blocks_to_swap :]: + b.to(self.device) # move block to device first + weighs_to_device(b, "cpu") # make sure weights are on cpu + + synchronize_device(self.device) + clean_memory_on_device(self.device) + + def wait_for_block(self, block_idx: int): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self._wait_blocks_move(block_idx) + + def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + if block_idx >= self.blocks_to_swap: + return + block_idx_to_cpu = block_idx + block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx + self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..41036d7083f4b053c5299bf44ce64cd999dd7123 --- /dev/null +++ b/library/custom_train_functions.py @@ -0,0 +1,556 @@ +import torch +import argparse +import random +import re +from typing import List, Optional, Union +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def prepare_scheduler_for_custom_training(noise_scheduler, device): + if hasattr(noise_scheduler, "all_snr"): + return + + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) + alpha = sqrt_alphas_cumprod + sigma = sqrt_one_minus_alphas_cumprod + all_snr = (alpha / sigma) ** 2 + + noise_scheduler.all_snr = all_snr.to(device) + + +def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): + # fix beta: zero terminal SNR + logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891") + + def enforce_zero_terminal_snr(betas): + # Convert betas to alphas_bar_sqrt + alphas = 1 - betas + alphas_bar = alphas.cumprod(0) + alphas_bar_sqrt = alphas_bar.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + # Shift so last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + # Scale so first timestep is back to old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 + alphas = alphas_bar[1:] / alphas_bar[:-1] + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + return betas + + betas = noise_scheduler.betas + betas = enforce_zero_terminal_snr(betas) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + + # logger.info(f"original: {noise_scheduler.betas}") + # logger.info(f"fixed: {betas}") + + noise_scheduler.betas = betas + noise_scheduler.alphas = alphas + noise_scheduler.alphas_cumprod = alphas_cumprod + + +def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False): + snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) + min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) + if v_prediction: + snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device) + else: + snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device) + loss = loss * snr_weight + return loss + + +def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler): + scale = get_snr_scale(timesteps, noise_scheduler) + loss = loss * scale + return loss + + +def get_snr_scale(timesteps, noise_scheduler): + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size + snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 + scale = snr_t / (snr_t + 1) + # # show debug info + # logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}") + return scale + + +def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss): + scale = get_snr_scale(timesteps, noise_scheduler) + # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") + loss = loss + loss / scale * v_pred_like_loss + return loss + + +def apply_debiased_estimation(loss, timesteps, noise_scheduler): + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size + snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 + weight = 1 / torch.sqrt(snr_t) + loss = weight * loss + return loss + + +# TODO train_utilと分散しているのでどちらかに寄せる + + +def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True): + parser.add_argument( + "--min_snr_gamma", + type=float, + default=None, + help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨", + ) + parser.add_argument( + "--scale_v_pred_loss_like_noise_pred", + action="store_true", + help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする", + ) + parser.add_argument( + "--v_pred_like_loss", + type=float, + default=None, + help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する", + ) + parser.add_argument( + "--debiased_estimation_loss", + action="store_true", + help="debiased estimation loss / debiased estimation loss", + ) + if support_weighted_captions: + parser.add_argument( + "--weighted_captions", + action="store_true", + default=False, + help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意", + ) + + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + tokenizer, + text_encoder, + text_input: torch.Tensor, + chunk_length: int, + clip_skip: int, + eos: int, + pad: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. + """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos + + if clip_skip is None or clip_skip == 1: + text_embedding = text_encoder(text_input_chunk)[0] + else: + enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) + text_embedding = enc_out["hidden_states"][-clip_skip] + text_embedding = text_encoder.text_model.final_layer_norm(text_embedding) + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + if clip_skip is None or clip_skip == 1: + text_embeddings = text_encoder(text_input)[0] + else: + enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-clip_skip] + text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings) + return text_embeddings + + +def get_weighted_text_embeddings( + tokenizer, + text_encoder, + prompt: Union[str, List[str]], + device, + max_embeddings_multiples: Optional[int] = 3, + no_boseos_middle: Optional[bool] = False, + clip_skip=None, +): + r""" + Prompts can be assigned with local weights using brackets. For example, + prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', + and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. + + Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + no_boseos_middle (`bool`, *optional*, defaults to `False`): + If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and + ending token in each of the chunk in the middle. + skip_parsing (`bool`, *optional*, defaults to `False`): + Skip the parsing of brackets. + skip_weighting (`bool`, *optional*, defaults to `False`): + Skip the weighting. When the parsing is skipped, it is forced True. + """ + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2) + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = tokenizer.bos_token_id + eos = tokenizer.eos_token_id + pad = tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) + + # get the embeddings + text_embeddings = get_unweighted_text_embeddings( + tokenizer, + text_encoder, + prompt_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) + + # assign weights to the prompts and normalize in the sense of mean + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + return text_embeddings + + +# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2 +def pyramid_noise_like(noise, device, iterations=6, discount=0.4): + b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant! + u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device) + for i in range(iterations): + r = random.random() * 2 + 2 # Rather than always going 2x, + wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i))) + noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i + if wn == 1 or hn == 1: + break # Lowest resolution is 1x1 + return noise / noise.std() # Scaled back to roughly unit variance + + +# https://www.crosslabs.org//blog/diffusion-with-offset-noise +def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): + if noise_offset is None: + return noise + if adaptive_noise_scale is not None: + # latent shape: (batch_size, channels, height, width) + # abs mean value for each channel + latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True)) + + # multiply adaptive noise scale to the mean value and add it to the noise offset + noise_offset = noise_offset + adaptive_noise_scale * latent_mean + noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative + + noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + return noise + + +def apply_masked_loss(loss, batch): + if "conditioning_images" in batch: + # conditioning image is -1 to 1. we need to convert it to 0 to 1 + mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel + mask_image = mask_image / 2 + 0.5 + # print(f"conditioning_image: {mask_image.shape}") + elif "alpha_masks" in batch and batch["alpha_masks"] is not None: + # alpha mask is 0 to 1 + mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension + # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}") + else: + return loss + + # resize to the same size as the loss + mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area") + loss = loss * mask_image + return loss + + +""" +########################################## +# Perlin Noise +def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): + delta = (res[0] / shape[0], res[1] / shape[1]) + d = (shape[0] // res[0], shape[1] // res[1]) + + grid = ( + torch.stack( + torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)), + dim=-1, + ) + % 1 + ) + angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device) + gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) + + tile_grads = ( + lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]] + .repeat_interleave(d[0], 0) + .repeat_interleave(d[1], 1) + ) + dot = lambda grad, shift: ( + torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1) + * grad[: shape[0], : shape[1]] + ).sum(dim=-1) + + n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) + n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) + n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) + n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) + t = fade(grid[: shape[0], : shape[1]]) + return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) + + +def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5): + noise = torch.zeros(shape, device=device) + frequency = 1 + amplitude = 1 + for _ in range(octaves): + noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1])) + frequency *= 2 + amplitude *= persistence + return noise + + +def perlin_noise(noise, device, octaves): + _, c, w, h = noise.shape + perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves) + noise_perlin = [] + for _ in range(c): + noise_perlin.append(perlin()) + noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h) + noise += noise_perlin # broadcast for each batch + return noise / noise.std() # Scaled back to roughly unit variance +""" diff --git a/library/deepspeed_utils.py b/library/deepspeed_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..39583c7ac4838ee7cdad6be9de1f3dcf3ae12836 --- /dev/null +++ b/library/deepspeed_utils.py @@ -0,0 +1,139 @@ +import os +import argparse +import torch +from accelerate import DeepSpeedPlugin, Accelerator + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def add_deepspeed_arguments(parser: argparse.ArgumentParser): + # DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed + parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training") + parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.") + parser.add_argument( + "--offload_optimizer_device", + type=str, + default=None, + choices=[None, "cpu", "nvme"], + help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.", + ) + parser.add_argument( + "--offload_optimizer_nvme_path", + type=str, + default=None, + help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.", + ) + parser.add_argument( + "--offload_param_device", + type=str, + default=None, + choices=[None, "cpu", "nvme"], + help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.", + ) + parser.add_argument( + "--offload_param_nvme_path", + type=str, + default=None, + help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.", + ) + parser.add_argument( + "--zero3_init_flag", + action="store_true", + help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models." + "Only applicable with ZeRO Stage-3.", + ) + parser.add_argument( + "--zero3_save_16bit_model", + action="store_true", + help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.", + ) + parser.add_argument( + "--fp16_master_weights_and_gradients", + action="store_true", + help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.", + ) + + +def prepare_deepspeed_args(args: argparse.Namespace): + if not args.deepspeed: + return + + # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1. + args.max_data_loader_n_workers = 1 + + +def prepare_deepspeed_plugin(args: argparse.Namespace): + if not args.deepspeed: + return None + + try: + import deepspeed + except ImportError as e: + logger.error( + "deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed" + ) + exit(1) + + deepspeed_plugin = DeepSpeedPlugin( + zero_stage=args.zero_stage, + gradient_accumulation_steps=args.gradient_accumulation_steps, + gradient_clipping=args.max_grad_norm, + offload_optimizer_device=args.offload_optimizer_device, + offload_optimizer_nvme_path=args.offload_optimizer_nvme_path, + offload_param_device=args.offload_param_device, + offload_param_nvme_path=args.offload_param_nvme_path, + zero3_init_flag=args.zero3_init_flag, + zero3_save_16bit_model=args.zero3_save_16bit_model, + ) + deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size + deepspeed_plugin.deepspeed_config["train_batch_size"] = 1#( + # args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"]) + #) + deepspeed_plugin.set_mixed_precision(args.mixed_precision) + if args.mixed_precision.lower() == "fp16": + deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow. + if args.full_fp16 or args.fp16_master_weights_and_gradients: + if args.offload_optimizer_device == "cpu" and args.zero_stage == 2: + deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True + logger.info("[DeepSpeed] full fp16 enable.") + else: + logger.info( + "[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage." + ) + + if args.offload_optimizer_device is not None: + logger.info("[DeepSpeed] start to manually build cpu_adam.") + deepspeed.ops.op_builder.CPUAdamBuilder().load() + logger.info("[DeepSpeed] building cpu_adam done.") + + return deepspeed_plugin + + +# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model. +def prepare_deepspeed_model(args: argparse.Namespace, **models): + # remove None from models + models = {k: v for k, v in models.items() if v is not None} + + class DeepSpeedWrapper(torch.nn.Module): + def __init__(self, **kw_models) -> None: + super().__init__() + self.models = torch.nn.ModuleDict() + + for key, model in kw_models.items(): + if isinstance(model, list): + model = torch.nn.ModuleList(model) + assert isinstance( + model, torch.nn.Module + ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}" + self.models.update(torch.nn.ModuleDict({key: model})) + + def get_models(self): + return self.models + + ds_model = DeepSpeedWrapper(**models) + return ds_model diff --git a/library/device_utils.py b/library/device_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e45ec7ada4082ab1c1273c4d5b419325e5132c50 --- /dev/null +++ b/library/device_utils.py @@ -0,0 +1,84 @@ +import functools +import gc + +import torch + +try: + HAS_CUDA = torch.cuda.is_available() +except Exception: + HAS_CUDA = False + +try: + HAS_MPS = torch.backends.mps.is_available() +except Exception: + HAS_MPS = False + +try: + import intel_extension_for_pytorch as ipex # noqa + + HAS_XPU = torch.xpu.is_available() +except Exception: + HAS_XPU = False + + +def clean_memory(): + gc.collect() + if HAS_CUDA: + torch.cuda.empty_cache() + if HAS_XPU: + torch.xpu.empty_cache() + if HAS_MPS: + torch.mps.empty_cache() + + +def clean_memory_on_device(device: torch.device): + r""" + Clean memory on the specified device, will be called from training scripts. + """ + gc.collect() + + # device may "cuda" or "cuda:0", so we need to check the type of device + if device.type == "cuda": + torch.cuda.empty_cache() + if device.type == "xpu": + torch.xpu.empty_cache() + if device.type == "mps": + torch.mps.empty_cache() + + +@functools.lru_cache(maxsize=None) +def get_preferred_device() -> torch.device: + r""" + Do not call this function from training scripts. Use accelerator.device instead. + """ + if HAS_CUDA: + device = torch.device("cuda") + elif HAS_XPU: + device = torch.device("xpu") + elif HAS_MPS: + device = torch.device("mps") + else: + device = torch.device("cpu") + print(f"get_preferred_device() -> {device}") + return device + + +def init_ipex(): + """ + Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`. + + This function should run right after importing torch and before doing anything else. + + If IPEX is not available, this function does nothing. + """ + try: + if HAS_XPU: + from .ipex import ipex_init + + is_initialized, error_message = ipex_init() + if not is_initialized: + print("failed to initialize ipex:", error_message) + else: + return + except Exception as e: + print("failed to initialize ipex:", e) diff --git a/library/flux_models.py b/library/flux_models.py new file mode 100644 index 0000000000000000000000000000000000000000..b2235d5888aa79183a5241da8bc03614d6eff826 --- /dev/null +++ b/library/flux_models.py @@ -0,0 +1,1060 @@ +# copy from FLUX repo: https://github.com/black-forest-labs/flux +# license: Apache-2.0 License + +from dataclasses import dataclass +import math +from typing import Dict, List, Optional, Union + +from .device_utils import init_ipex +from .custom_offloading_utils import ModelOffloader +init_ipex() + +import torch +from einops import rearrange +from torch import Tensor, nn +from torch.utils.checkpoint import checkpoint + +# USE_REENTRANT = True + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +# region autoencoder + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) + + +# endregion +# region config + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: str | None + ae_path: str | None + # repo_id: str | None + # repo_flow: str | None + # repo_ae: str | None + + +configs = { + "dev": ModelSpec( + # repo_id="black-forest-labs/FLUX.1-dev", + # repo_flow="flux1-dev.sft", + # repo_ae="ae.sft", + ckpt_path=None, # os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=None, # os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "schnell": ModelSpec( + # repo_id="black-forest-labs/FLUX.1-schnell", + # repo_flow="flux1-schnell.sft", + # repo_ae="ae.sft", + ckpt_path=None, # os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=None, # os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +# endregion + +# region math + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor: + q, k = apply_rope(q, k, pe) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + +# endregion + + +# region layers + + +# for cpu_offload_checkpointing + + +def to_cuda(x): + if isinstance(x, torch.Tensor): + return x.cuda() + elif isinstance(x, (list, tuple)): + return [to_cuda(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cuda(v) for k, v in x.items()} + else: + return x + + +def to_cpu(x): + if isinstance(x, torch.Tensor): + return x.cpu() + elif isinstance(x, (list, tuple)): + return [to_cpu(elem) for elem in x] + elif isinstance(x, dict): + return {k: to_cpu(v) for k, v in x.items()} + else: + return x + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + # def forward(self, x): + # if self.training and self.gradient_checkpointing: + # def create_custom_forward(func): + # def custom_forward(*inputs): + # return func(*inputs) + # return custom_forward + # return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, use_reentrant=USE_REENTRANT) + # else: + # return self._forward(x) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + # return (x * rrms).to(dtype=x_dtype) * self.scale + return ((x * rrms) * self.scale.float()).to(dtype=x_dtype) + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + # this is not called from DoubleStreamBlock/SingleStreamBlock because they uses attention function directly + def forward(self, x: Tensor, pe: Tensor) -> Tensor: + qkv = self.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = self.proj(x) + return x + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def _forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + # make attention mask if not None + attn_mask = None + if txt_attention_mask is not None: + # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask + attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len + attn_mask = torch.cat( + (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1 + ) # b, seq_len + img_len + + # broadcast attn_mask to all heads + attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1) + + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img blocks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + + # calculate the txt blocks + txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + return img, txt + + def forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None + ) -> tuple[Tensor, Tensor]: + if self.training and self.gradient_checkpointing: + if not self.cpu_offload_checkpointing: + return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False) + # cpu offload checkpointing + + def create_custom_forward(func): + def custom_forward(*inputs): + cuda_inputs = to_cuda(inputs) + outputs = func(*cuda_inputs) + return to_cpu(outputs) + + return custom_forward + + return torch.utils.checkpoint.checkpoint( + create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False + ) + + else: + return self._forward(img, txt, vec, pe, txt_attention_mask) + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + # make attention mask if not None + attn_mask = None + if txt_attention_mask is not None: + # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask + attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len + attn_mask = torch.cat( + ( + attn_mask, + torch.ones( + attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool + ), + ), + dim=1, + ) # b, seq_len + img_len = x_len + + # broadcast attn_mask to all heads + attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1) + + # compute attention + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) + + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + return x + mod.gate * output + + def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: + if self.training and self.gradient_checkpointing: + if not self.cpu_offload_checkpointing: + return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False) + + # cpu offload checkpointing + + def create_custom_forward(func): + def custom_forward(*inputs): + cuda_inputs = to_cuda(inputs) + outputs = func(*cuda_inputs) + return to_cpu(outputs) + + return custom_forward + + return torch.utils.checkpoint.checkpoint( + create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False + ) + else: + return self._forward(x, vec, pe, txt_attention_mask) + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x + + +# endregion + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + self.blocks_to_swap = None + + self.offloader_double = None + self.offloader_single = None + self.num_double_blocks = len(self.double_blocks) + self.num_single_blocks = len(self.single_blocks) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + self.time_in.enable_gradient_checkpointing() + self.vector_in.enable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.enable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing(cpu_offload=cpu_offload) + + print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + self.time_in.disable_gradient_checkpointing() + self.vector_in.disable_gradient_checkpointing() + if self.guidance_in.__class__ != nn.Identity: + self.guidance_in.disable_gradient_checkpointing() + + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print("FLUX: Gradient checkpointing disabled.") + + def enable_block_swap(self, num_blocks: int, device: torch.device): + self.blocks_to_swap = num_blocks + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + + assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( + f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. " + f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." + ) + + self.offloader_double = ModelOffloader( + self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True + ) + self.offloader_single = ModelOffloader( + self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True + ) + print( + f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." + ) + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_double_blocks = self.double_blocks + save_single_blocks = self.single_blocks + self.double_blocks = None + self.single_blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.double_blocks = save_double_blocks + self.single_blocks = save_single_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + if not self.blocks_to_swap: + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + else: + for block_idx, block in enumerate(self.double_blocks): + self.offloader_double.wait_for_block(block_idx) + + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + + self.offloader_double.submit_move_blocks(self.double_blocks, block_idx) + + img = torch.cat((txt, img), 1) + + for block_idx, block in enumerate(self.single_blocks): + self.offloader_single.wait_for_block(block_idx) + + img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + + self.offloader_single.submit_move_blocks(self.single_blocks, block_idx) + + img = img[:, txt.shape[1] :, ...] + + if self.training and self.cpu_offload_checkpointing: + img = img.to(self.device) + vec = vec.to(self.device) + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + + return img diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f69c61d6fa8478e635415734d5f247ec130d476a --- /dev/null +++ b/library/flux_train_utils.py @@ -0,0 +1,585 @@ +import argparse +import math +import os +import numpy as np +import toml +import json +import time +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from accelerate import Accelerator, PartialState +from transformers import CLIPTextModel +from tqdm import tqdm +from PIL import Image + +from safetensors.torch import save_file +from . import flux_models, flux_utils, strategy_base, train_util +from .device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from .utils import setup_logging, mem_eff_save_file + +setup_logging() +import logging + +logger = logging.getLogger(__name__) +# from comfy.utils import ProgressBar + +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + flux, + ae, + text_encoders, + sample_prompts_te_outputs, + validation_settings=None, + prompt_replacement=None, +): + + logger.info("") + logger.info(f"generating sample images at step: {steps}") + + #distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + # unwrap unet and text_encoder(s) + flux = accelerator.unwrap_model(flux) + if text_encoders is not None: + text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = [] + for line in args.sample_prompts: + line = line.strip() + if len(line) > 0 and line[0] != "#": + prompts.append(line) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from .train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + with torch.no_grad(), accelerator.autocast(): + image_tensor_list = [] + for prompt_dict in prompts: + image_tensor = sample_image_inference( + accelerator, + args, + flux, + text_encoders, + ae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + validation_settings + ) + image_tensor_list.append(image_tensor) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + clean_memory_on_device(accelerator.device) + return torch.cat(image_tensor_list, dim=0) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + flux: flux_models.Flux, + text_encoders: Optional[List[CLIPTextModel]], + ae: flux_models.AutoEncoder, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + validation_settings=None +): + assert isinstance(prompt_dict, dict) + # negative_prompt = prompt_dict.get("negative_prompt") + if validation_settings is not None: + sample_steps = validation_settings["steps"] + width = validation_settings["width"] + height = validation_settings["height"] + scale = validation_settings["guidance_scale"] + seed = validation_settings["seed"] + base_shift = validation_settings["base_shift"] + max_shift = validation_settings["max_shift"] + shift = validation_settings["shift"] + else: + sample_steps = prompt_dict.get("sample_steps", 20) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 3.5) + seed = prompt_dict.get("seed") + base_shift = 0.5 + max_shift = 1.15 + shift = True + # controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + # if negative_prompt is not None: + # negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + # if negative_prompt is None: + # negative_prompt = "" + + height = max(64, height - height % 16) # round to divisible by 16 + width = max(64, width - width % 16) # round to divisible by 16 + logger.info(f"prompt: {prompt}") + # logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + # logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + text_encoder_conds = [] + if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: + text_encoder_conds = sample_prompts_te_outputs[prompt] + print(f"Using cached text encoder outputs for prompt: {prompt}") + if text_encoders is not None: + print(f"Encoding prompt: {prompt}") + tokens_and_masks = tokenize_strategy.tokenize(prompt) + # strategy has apply_t5_attn_mask option + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + + # sample image + weight_dtype = ae.dtype # TOFO give dtype as argument + packed_latent_height = height // 16 + packed_latent_width = width // 16 + noise = torch.randn( + 1, + packed_latent_height * packed_latent_width, + 16 * 2 * 2, + device=accelerator.device, + dtype=weight_dtype, + generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None, + ) + timesteps = get_schedule(sample_steps, noise.shape[1], base_shift=base_shift, max_shift=max_shift, shift=shift) # FLUX.1 dev -> shift=True + #print("TIMESTEPS: ", timesteps) + img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) + t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None + + with accelerator.autocast(), torch.no_grad(): + x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask) + + x = x.float() + x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) + + # latent to image + clean_memory_on_device(accelerator.device) + org_vae_device = ae.device # will be on cpu + ae.to(accelerator.device) # distributed_state.device is same as accelerator.device + with accelerator.autocast(), torch.no_grad(): + x = ae.decode(x) + ae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + + x = x.clamp(-1, 1) + x = x.permute(0, 2, 3, 1) + image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + return x + + # wandb有効時のみログを送信 + # try: + # wandb_tracker = accelerator.get_tracker("wandb") + # try: + # import wandb + # except ImportError: # 事前に一度確認するのでここはエラー出ないはず + # raise ImportError("No wandb / wandb がインストールされていないようです") + + # wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + # except: # wandb 無効時 + # pass + + +def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: flux_models.Flux, + img: torch.Tensor, + img_ids: torch.Tensor, + txt: torch.Tensor, + txt_ids: torch.Tensor, + vec: torch.Tensor, + timesteps: list[float], + guidance: float = 4.0, + t5_attn_mask: Optional[torch.Tensor] = None, +): + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + # comfy_pbar = ProgressBar(total=len(timesteps)) + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + model.prepare_block_swap_before_forward() + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + img = img + (t_prev - t_curr) * pred + # comfy_pbar.update(1) + model.prepare_block_swap_before_forward() + return img + +# endregion + + +# region train +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +def get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, device, dtype +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz, _, H, W = latents.shape + sigmas = None + + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": + # Simple random t-based noise sampling + if args.timestep_sampling == "sigmoid": + # https://github.com/XLabs-AI/x-flux/tree/main + t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + else: + t = torch.rand((bsz,), device=device) + + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "shift": + shift = args.discrete_flow_shift + logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise + elif args.timestep_sampling == "flux_shift": + logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + mu=get_lin_function(y1=0.5, y2=1.15)((H//2) * (W//2)) + timesteps = time_shift(mu, 1.0, timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise + else: + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps[indices].to(device=device) + + # Add noise according to flow matching. + sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + return noisy_model_input, timesteps, sigmas + + +def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): + weighting = None + if args.model_prediction_type == "raw": + pass + elif args.model_prediction_type == "additive": + # add the model_pred to the noisy_model_input + model_pred = model_pred + noisy_model_input + elif args.model_prediction_type == "sigma_scaled": + # apply sigma scaling + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + return model_pred, weighting + + +def save_models( + ckpt_path: str, + flux: flux_models.Flux, + sai_metadata: Optional[dict], + save_dtype: Optional[torch.dtype] = None, + use_mem_eff_save: bool = False, +): + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None and v.dtype != save_dtype: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + update_sd("", flux.state_dict()) + + if not use_mem_eff_save: + save_file(state_dict, ckpt_path, metadata=sai_metadata) + else: + mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata) + + +def save_flux_model_on_train_end( + args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") + save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save) + + train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_flux_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + flux: flux_models.Flux, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev") + save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + True, + True, + epoch, + num_train_epochs, + global_step, + sd_saver, + None, + ) + + +# endregion + + +def add_flux_train_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--clip_l", + type=str, + help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument( + "--t5xxl", + type=str, + help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=None, + help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev" + " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512", + ) + parser.add_argument( + "--apply_t5_attn_mask", + action="store_true", + help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--timestep_sampling", + choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"], + default="sigma", + help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting." + " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。", + ) + parser.add_argument( + "--sigmoid_scale", + type=float, + default=1.0, + help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。', + ) + parser.add_argument( + "--model_prediction_type", + choices=["raw", "additive", "sigma_scaled"], + default="sigma_scaled", + help="How to interpret and process the model prediction: " + "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." + " / モデル予測の解釈と処理方法:" + "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", + ) + parser.add_argument( + "--discrete_flow_shift", + type=float, + default=3.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", + ) + parser.add_argument( + "--bypass_flux_guidance" + , action="store_true" + , help="bypass flux guidance module for Flex.1-Alpha Training" + ) diff --git a/library/flux_utils.py b/library/flux_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..62c58bd10119b78779097fdd6e4dbce43deba9c5 --- /dev/null +++ b/library/flux_utils.py @@ -0,0 +1,474 @@ +from dataclasses import replace +import json +import os +from typing import List, Optional, Tuple, Union +import einops +import torch + +from safetensors import safe_open +from accelerate import init_empty_weights +from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config + +from .flux_models import Flux, AutoEncoder, configs +from .utils import setup_logging, load_safetensors + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +MODEL_VERSION_FLUX_V1 = "flux1" +MODEL_NAME_DEV = "dev" +MODEL_NAME_SCHNELL = "schnell" + +# bypass guidance +def bypass_flux_guidance(transformer): + transformer.params.guidance_embed = False + +# restore the forward function +def restore_flux_guidance(transformer): + transformer.params.guidance_embed = True + +def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: + """ + チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。 + + Args: + ckpt_path (str): チェックポイントファイルまたはディレクトリのパス。 + + Returns: + Tuple[bool, bool, Tuple[int, int], List[str]]: + - bool: Diffusersかどうかを示すフラグ。 + - bool: Schnellかどうかを示すフラグ。 + - Tuple[int, int]: ダブルブロックとシングルブロックの数。 + - List[str]: チェックポイントに含まれるキーのリスト。 + """ + # check the state dict: Diffusers or BFL, dev or schnell, number of blocks + logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell") + + if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers + ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors") + if "00001-of-00003" in ckpt_path: + ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)] + else: + ckpt_paths = [ckpt_path] + + keys = [] + for ckpt_path in ckpt_paths: + with safe_open(ckpt_path, framework="pt") as f: + keys.extend(f.keys()) + + if keys[0].startswith("model.diffusion_model."): + keys = [key.replace("model.diffusion_model.", "") for key in keys] + + is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys + is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) + + # check number of double and single blocks + if not is_diffusers: + max_double_block_index = max( + [int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")] + ) + max_single_block_index = max( + [int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")] + ) + else: + max_double_block_index = max( + [ + int(key.split(".")[1]) + for key in keys + if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias") + ] + ) + max_single_block_index = max( + [ + int(key.split(".")[1]) + for key in keys + if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias") + ] + ) + + num_double_blocks = max_double_block_index + 1 + num_single_blocks = max_single_block_index + 1 + + return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths + + +def load_flow_model( + ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False +) -> Tuple[bool, Flux]: + is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + + # build model + logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint") + with torch.device("meta"): + params = configs[name].params + + # set the number of blocks + if params.depth != num_double_blocks: + logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}") + params = replace(params, depth=num_double_blocks) + if params.depth_single_blocks != num_single_blocks: + logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}") + params = replace(params, depth_single_blocks=num_single_blocks) + + model = Flux(params) + if dtype is not None: + model = model.to(dtype) + + # load_sft doesn't support torch.device + logger.info(f"Loading state dict from {ckpt_path}") + sd = {} + for ckpt_path in ckpt_paths: + sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)) + + # convert Diffusers to BFL + if is_diffusers: + logger.info("Converting Diffusers to BFL") + sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks) + logger.info("Converted Diffusers to BFL") + + for key in list(sd.keys()): + new_key = key.replace("model.diffusion_model.", "") + if new_key == key: + break + sd[new_key] = sd.pop(key) + + info = model.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded Flux: {info}") + return is_schnell, model + + +def load_ae( + ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False +) -> AutoEncoder: + logger.info("Building AutoEncoder") + with torch.device("meta"): + # dev and schnell have the same AE params + ae = AutoEncoder(configs[MODEL_NAME_DEV].ae_params).to(dtype) + + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = ae.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded AE: {info}") + return ae + + +def load_clip_l( + ckpt_path: Optional[str], + dtype: torch.dtype, + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> CLIPTextModel: + logger.info("Building CLIP-L") + CLIPL_CONFIG = { + "_name_or_path": "clip-vit-large-patch14/", + "architectures": ["CLIPModel"], + "initializer_factor": 1.0, + "logit_scale_init_value": 2.6592, + "model_type": "clip", + "projection_dim": 768, + # "text_config": { + "_name_or_path": "", + "add_cross_attention": False, + "architectures": None, + "attention_dropout": 0.0, + "bad_words_ids": None, + "bos_token_id": 0, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": None, + "decoder_start_token_id": None, + "diversity_penalty": 0.0, + "do_sample": False, + "dropout": 0.0, + "early_stopping": False, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": 2, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "hidden_act": "quick_gelu", + "hidden_size": 768, + "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 3072, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "layer_norm_eps": 1e-05, + "length_penalty": 1.0, + "max_length": 20, + "max_position_embeddings": 77, + "min_length": 0, + "model_type": "clip_text_model", + "no_repeat_ngram_size": 0, + "num_attention_heads": 12, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 12, + "num_return_sequences": 1, + "output_attentions": False, + "output_hidden_states": False, + "output_scores": False, + "pad_token_id": 1, + "prefix": None, + "problem_type": None, + "projection_dim": 768, + "pruned_heads": {}, + "remove_invalid_values": False, + "repetition_penalty": 1.0, + "return_dict": True, + "return_dict_in_generate": False, + "sep_token_id": None, + "task_specific_params": None, + "temperature": 1.0, + "tie_encoder_decoder": False, + "tie_word_embeddings": True, + "tokenizer_class": None, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": None, + "torchscript": False, + "transformers_version": "4.16.0.dev0", + "use_bfloat16": False, + "vocab_size": 49408, + "hidden_act": "gelu", + "hidden_size": 1280, + "intermediate_size": 5120, + "num_attention_heads": 20, + "num_hidden_layers": 32, + # }, + # "text_config_dict": { + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "projection_dim": 768, + # }, + # "torch_dtype": "float32", + # "transformers_version": None, + } + config = CLIPConfig(**CLIPL_CONFIG) + with init_empty_weights(): + clip = CLIPTextModel._from_config(config) + + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = clip.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded CLIP-L: {info}") + return clip + + +def load_t5xxl( + ckpt_path: str, + dtype: Optional[torch.dtype], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> T5EncoderModel: + T5_CONFIG_JSON = """ +{ + "architectures": [ + "T5EncoderModel" + ], + "classifier_dropout": 0.0, + "d_ff": 10240, + "d_kv": 64, + "d_model": 4096, + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 24, + "num_heads": 64, + "num_layers": 24, + "output_past": true, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.41.2", + "use_cache": true, + "vocab_size": 32128 +} +""" + config = json.loads(T5_CONFIG_JSON) + config = T5Config(**config) + with init_empty_weights(): + t5xxl = T5EncoderModel._from_config(config) + + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = t5xxl.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded T5xxl: {info}") + return t5xxl + + +def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype: + # nn.Embedding is the first layer, but it could be casted to bfloat16 or float32 + return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype + + +def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int): + img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :] + img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + return img_ids + + +def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: + """ + x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) + return x + + +def pack_latents(x: torch.Tensor) -> torch.Tensor: + """ + x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2 + """ + x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + return x + + +# region Diffusers + +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + +BFL_TO_DIFFUSERS_MAP = { + "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], + "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], + "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], + "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], + "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], + "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], + "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], + "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], + "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], + "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], + "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], + "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], + "txt_in.weight": ["context_embedder.weight"], + "txt_in.bias": ["context_embedder.bias"], + "img_in.weight": ["x_embedder.weight"], + "img_in.bias": ["x_embedder.bias"], + "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], + "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], + "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], + "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], + "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], + "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], + "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], + "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], + "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], + "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], + "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], + "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], + "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], + "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], + "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], + "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], + "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], + "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], + "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], + "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], + "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], + "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], + "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], + "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], + "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], + "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], + "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], + "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], + "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().linear2.bias": ["proj_out.bias"], + "final_layer.linear.weight": ["proj_out.weight"], + "final_layer.linear.bias": ["proj_out.bias"], + "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], + "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], +} + + +def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]: + # make reverse map from diffusers map + diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) + for b in range(num_double_blocks): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("double_blocks."): + block_prefix = f"transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for b in range(num_single_blocks): + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if key.startswith("single_blocks."): + block_prefix = f"single_transformer_blocks.{b}." + for i, weight in enumerate(weights): + diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}")) + for key, weights in BFL_TO_DIFFUSERS_MAP.items(): + if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")): + for i, weight in enumerate(weights): + diffusers_to_bfl_map[weight] = (i, key) + return diffusers_to_bfl_map + + +def convert_diffusers_sd_to_bfl( + diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS +) -> dict[str, torch.Tensor]: + diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks) + + # iterate over three safetensors files to reduce memory usage + flux_sd = {} + for diffusers_key, tensor in diffusers_sd.items(): + if diffusers_key in diffusers_to_bfl_map: + index, bfl_key = diffusers_to_bfl_map[diffusers_key] + if bfl_key not in flux_sd: + flux_sd[bfl_key] = [] + flux_sd[bfl_key].append((index, tensor)) + else: + logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}") + raise KeyError(f"Key not found in diffusers_to_bfl_map: {diffusers_key}") + + # concat tensors if multiple tensors are mapped to a single key, sort by index + for key, values in flux_sd.items(): + if len(values) == 1: + flux_sd[key] = values[0][1] + else: + flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])]) + + # special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + if "final_layer.adaLN_modulation.1.weight" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"]) + if "final_layer.adaLN_modulation.1.bias" in flux_sd: + flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"]) + + return flux_sd + + +# endregion \ No newline at end of file diff --git a/library/huggingface_util.py b/library/huggingface_util.py new file mode 100644 index 0000000000000000000000000000000000000000..07a97606ef935a88a77756c287cfb57495afc5b0 --- /dev/null +++ b/library/huggingface_util.py @@ -0,0 +1,84 @@ +from typing import Union, BinaryIO +from huggingface_hub import HfApi +from pathlib import Path +import argparse +import os +from .utils import fire_in_thread +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) + +def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None): + api = HfApi( + token=token, + ) + try: + api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) + return True + except: + return False + + +def upload( + args: argparse.Namespace, + src: Union[str, Path, bytes, BinaryIO], + dest_suffix: str = "", + force_sync_upload: bool = False, +): + repo_id = args.huggingface_repo_id + repo_type = args.huggingface_repo_type + token = args.huggingface_token + path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None + private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public" + api = HfApi(token=token) + if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token): + try: + api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) + except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので + logger.error("===========================================") + logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}") + logger.error("===========================================") + + is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir()) + + def uploader(): + try: + if is_folder: + api.upload_folder( + repo_id=repo_id, + repo_type=repo_type, + folder_path=src, + path_in_repo=path_in_repo, + ) + else: + api.upload_file( + repo_id=repo_id, + repo_type=repo_type, + path_or_fileobj=src, + path_in_repo=path_in_repo, + ) + except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので + logger.error("===========================================") + logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") + logger.error("===========================================") + + if args.async_upload and not force_sync_upload: + fire_in_thread(uploader) + else: + uploader() + + +def list_dir( + repo_id: str, + subfolder: str, + repo_type: str, + revision: str = "main", + token: str = None, +): + api = HfApi( + token=token, + ) + repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type) + file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)] + return file_list diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5aba693c50450393872be2456dff8f1accabb3d --- /dev/null +++ b/library/ipex/__init__.py @@ -0,0 +1,180 @@ +import os +import sys +import contextlib +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +from .hijacks import ipex_hijacks + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +def ipex_init(): # pylint: disable=too-many-statements + try: + if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked: + return True, "Skipping IPEX hijack" + else: + # Replace cuda with xpu: + torch.cuda.current_device = torch.xpu.current_device + torch.cuda.current_stream = torch.xpu.current_stream + torch.cuda.device = torch.xpu.device + torch.cuda.device_count = torch.xpu.device_count + torch.cuda.device_of = torch.xpu.device_of + torch.cuda.get_device_name = torch.xpu.get_device_name + torch.cuda.get_device_properties = torch.xpu.get_device_properties + torch.cuda.init = torch.xpu.init + torch.cuda.is_available = torch.xpu.is_available + torch.cuda.is_initialized = torch.xpu.is_initialized + torch.cuda.is_current_stream_capturing = lambda: False + torch.cuda.set_device = torch.xpu.set_device + torch.cuda.stream = torch.xpu.stream + torch.cuda.synchronize = torch.xpu.synchronize + torch.cuda.Event = torch.xpu.Event + torch.cuda.Stream = torch.xpu.Stream + torch.cuda.FloatTensor = torch.xpu.FloatTensor + torch.Tensor.cuda = torch.Tensor.xpu + torch.Tensor.is_cuda = torch.Tensor.is_xpu + torch.nn.Module.cuda = torch.nn.Module.xpu + torch.UntypedStorage.cuda = torch.UntypedStorage.xpu + torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock + torch.cuda._initialized = torch.xpu.lazy_init._initialized + torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls + torch.cuda._tls = torch.xpu.lazy_init._tls + torch.cuda.threading = torch.xpu.lazy_init.threading + torch.cuda.traceback = torch.xpu.lazy_init.traceback + torch.cuda.Optional = torch.xpu.Optional + torch.cuda.__cached__ = torch.xpu.__cached__ + torch.cuda.__loader__ = torch.xpu.__loader__ + torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage + torch.cuda.Tuple = torch.xpu.Tuple + torch.cuda.streams = torch.xpu.streams + torch.cuda._lazy_new = torch.xpu._lazy_new + torch.cuda.FloatStorage = torch.xpu.FloatStorage + torch.cuda.Any = torch.xpu.Any + torch.cuda.__doc__ = torch.xpu.__doc__ + torch.cuda.default_generators = torch.xpu.default_generators + torch.cuda.HalfTensor = torch.xpu.HalfTensor + torch.cuda._get_device_index = torch.xpu._get_device_index + torch.cuda.__path__ = torch.xpu.__path__ + torch.cuda.Device = torch.xpu.Device + torch.cuda.IntTensor = torch.xpu.IntTensor + torch.cuda.ByteStorage = torch.xpu.ByteStorage + torch.cuda.set_stream = torch.xpu.set_stream + torch.cuda.BoolStorage = torch.xpu.BoolStorage + torch.cuda.os = torch.xpu.os + torch.cuda.torch = torch.xpu.torch + torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage + torch.cuda.Union = torch.xpu.Union + torch.cuda.DoubleTensor = torch.xpu.DoubleTensor + torch.cuda.ShortTensor = torch.xpu.ShortTensor + torch.cuda.LongTensor = torch.xpu.LongTensor + torch.cuda.IntStorage = torch.xpu.IntStorage + torch.cuda.LongStorage = torch.xpu.LongStorage + torch.cuda.__annotations__ = torch.xpu.__annotations__ + torch.cuda.__package__ = torch.xpu.__package__ + torch.cuda.__builtins__ = torch.xpu.__builtins__ + torch.cuda.CharTensor = torch.xpu.CharTensor + torch.cuda.List = torch.xpu.List + torch.cuda._lazy_init = torch.xpu._lazy_init + torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor + torch.cuda.DoubleStorage = torch.xpu.DoubleStorage + torch.cuda.ByteTensor = torch.xpu.ByteTensor + torch.cuda.StreamContext = torch.xpu.StreamContext + torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage + torch.cuda.ShortStorage = torch.xpu.ShortStorage + torch.cuda._lazy_call = torch.xpu._lazy_call + torch.cuda.HalfStorage = torch.xpu.HalfStorage + torch.cuda.random = torch.xpu.random + torch.cuda._device = torch.xpu._device + torch.cuda.classproperty = torch.xpu.classproperty + torch.cuda.__name__ = torch.xpu.__name__ + torch.cuda._device_t = torch.xpu._device_t + torch.cuda.warnings = torch.xpu.warnings + torch.cuda.__spec__ = torch.xpu.__spec__ + torch.cuda.BoolTensor = torch.xpu.BoolTensor + torch.cuda.CharStorage = torch.xpu.CharStorage + torch.cuda.__file__ = torch.xpu.__file__ + torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork + # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + + # Memory: + torch.cuda.memory = torch.xpu.memory + if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): + torch.xpu.empty_cache = lambda: None + torch.cuda.empty_cache = torch.xpu.empty_cache + torch.cuda.memory_stats = torch.xpu.memory_stats + torch.cuda.memory_summary = torch.xpu.memory_summary + torch.cuda.memory_snapshot = torch.xpu.memory_snapshot + torch.cuda.memory_allocated = torch.xpu.memory_allocated + torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated + torch.cuda.memory_reserved = torch.xpu.memory_reserved + torch.cuda.memory_cached = torch.xpu.memory_reserved + torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved + torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved + torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats + torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict + torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats + + # RNG: + torch.cuda.get_rng_state = torch.xpu.get_rng_state + torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all + torch.cuda.set_rng_state = torch.xpu.set_rng_state + torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all + torch.cuda.manual_seed = torch.xpu.manual_seed + torch.cuda.manual_seed_all = torch.xpu.manual_seed_all + torch.cuda.seed = torch.xpu.seed + torch.cuda.seed_all = torch.xpu.seed_all + torch.cuda.initial_seed = torch.xpu.initial_seed + + # AMP: + torch.cuda.amp = torch.xpu.amp + torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled + torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype + + if not hasattr(torch.cuda.amp, "common"): + torch.cuda.amp.common = contextlib.nullcontext() + torch.cuda.amp.common.amp_definitely_not_available = lambda: False + + try: + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + try: + from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error + gradscaler_init() + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + + # C + torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream + ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count + ipex._C._DeviceProperties.major = 2024 + ipex._C._DeviceProperties.minor = 0 + + # Fix functions with ipex: + torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] + torch._utils._get_available_device_type = lambda: "xpu" + torch.has_cuda = True + torch.cuda.has_half = True + torch.cuda.is_bf16_supported = lambda *args, **kwargs: True + torch.cuda.is_fp16_supported = lambda *args, **kwargs: True + torch.backends.cuda.is_built = lambda *args, **kwargs: True + torch.version.cuda = "12.1" + torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1] + torch.cuda.get_device_properties.major = 12 + torch.cuda.get_device_properties.minor = 1 + torch.cuda.ipc_collect = lambda *args, **kwargs: None + torch.cuda.utilization = lambda *args, **kwargs: 0 + + ipex_hijacks() + if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None: + try: + from .diffusers import ipex_diffusers + ipex_diffusers() + except Exception: # pylint: disable=broad-exception-caught + pass + torch.cuda.is_xpu_hijacked = True + except Exception as e: + return False, e + return True, None diff --git a/library/ipex/attention.py b/library/ipex/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..2bc62f65c3b3bde29559814ee0c4a92d71a306f8 --- /dev/null +++ b/library/ipex/attention.py @@ -0,0 +1,177 @@ +import os +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +from functools import cache + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers + +sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4)) +attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) + +# Find something divisible with the input_tokens +@cache +def find_slice_size(slice_size, slice_block_size): + while (slice_size * slice_block_size) > attention_slice_rate: + slice_size = slice_size // 2 + if slice_size <= 1: + slice_size = 1 + break + return slice_size + +# Find slice sizes for SDPA +@cache +def find_sdpa_slice_sizes(query_shape, query_element_size): + if len(query_shape) == 3: + batch_size_attention, query_tokens, shape_three = query_shape + shape_four = 1 + else: + batch_size_attention, query_tokens, shape_three, shape_four = query_shape + + slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size + block_size = batch_size_attention * slice_block_size + + split_slice_size = batch_size_attention + split_2_slice_size = query_tokens + split_3_slice_size = shape_three + + do_split = False + do_split_2 = False + do_split_3 = False + + if block_size > sdpa_slice_trigger_rate: + do_split = True + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) + if split_2_slice_size * slice_2_block_size > attention_slice_rate: + slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + +# Find slice sizes for BMM +@cache +def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape): + batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2] + slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size + block_size = batch_size_attention * slice_block_size + + split_slice_size = batch_size_attention + split_2_slice_size = input_tokens + split_3_slice_size = mat2_atten_shape + + do_split = False + do_split_2 = False + do_split_3 = False + + if block_size > attention_slice_rate: + do_split = True + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) + if split_2_slice_size * slice_2_block_size > attention_slice_rate: + slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + + +original_torch_bmm = torch.bmm +def torch_bmm_32_bit(input, mat2, *, out=None): + if input.device.type != "xpu": + return original_torch_bmm(input, mat2, out=out) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape) + + # Slice BMM + if do_split: + batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2] + hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype) + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + if do_split_3: + for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm( + input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + out=out + ) + else: + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm( + input[start_idx:end_idx, start_idx_2:end_idx_2], + mat2[start_idx:end_idx, start_idx_2:end_idx_2], + out=out + ) + else: + hidden_states[start_idx:end_idx] = original_torch_bmm( + input[start_idx:end_idx], + mat2[start_idx:end_idx], + out=out + ) + torch.xpu.synchronize(input.device) + else: + return original_torch_bmm(input, mat2, out=out) + return hidden_states + +original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention +def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs): + if query.device.type != "xpu": + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size()) + + # Slice SDPA + if do_split: + batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] + hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + if do_split_3: + for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention( + query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], + attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal, **kwargs + ) + else: + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( + query[start_idx:end_idx, start_idx_2:end_idx_2], + key[start_idx:end_idx, start_idx_2:end_idx_2], + value[start_idx:end_idx, start_idx_2:end_idx_2], + attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal, **kwargs + ) + else: + hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, + dropout_p=dropout_p, is_causal=is_causal, **kwargs + ) + torch.xpu.synchronize(query.device) + else: + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) + return hidden_states diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py new file mode 100644 index 0000000000000000000000000000000000000000..732a185689e9882d63be68b5c5d6ee6d82c74f71 --- /dev/null +++ b/library/ipex/diffusers.py @@ -0,0 +1,312 @@ +import os +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +import diffusers #0.24.0 # pylint: disable=import-error +from diffusers.models.attention_processor import Attention +from diffusers.utils import USE_PEFT_BACKEND +from functools import cache + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) + +@cache +def find_slice_size(slice_size, slice_block_size): + while (slice_size * slice_block_size) > attention_slice_rate: + slice_size = slice_size // 2 + if slice_size <= 1: + slice_size = 1 + break + return slice_size + +@cache +def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None): + if len(query_shape) == 3: + batch_size_attention, query_tokens, shape_three = query_shape + shape_four = 1 + else: + batch_size_attention, query_tokens, shape_three, shape_four = query_shape + if slice_size is not None: + batch_size_attention = slice_size + + slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size + block_size = batch_size_attention * slice_block_size + + split_slice_size = batch_size_attention + split_2_slice_size = query_tokens + split_3_slice_size = shape_three + + do_split = False + do_split_2 = False + do_split_3 = False + + if query_device_type != "xpu": + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + + if block_size > attention_slice_rate: + do_split = True + split_slice_size = find_slice_size(split_slice_size, slice_block_size) + if split_slice_size * slice_block_size > attention_slice_rate: + slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size + do_split_2 = True + split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size) + if split_2_slice_size * slice_2_block_size > attention_slice_rate: + slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size + do_split_3 = True + split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size) + + return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size + +class SlicedAttnProcessor: # pylint: disable=too-few-public-methods + r""" + Processor for implementing sliced attention. + + Args: + slice_size (`int`, *optional*): + The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and + `attention_head_dim` must be a multiple of the `slice_size`. + """ + + def __init__(self, slice_size): + self.slice_size = slice_size + + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, + encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches + + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + batch_size_attention, query_tokens, shape_three = query.shape + hidden_states = torch.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype + ) + + #################################################################### + # ARC GPUs can't allocate more than 4GB to a single block, Slice it: + _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size) + + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + if do_split_3: + for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice + del attn_slice + else: + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + del attn_slice + torch.xpu.synchronize(query.device) + else: + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + del attn_slice + #################################################################### + + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, + encoder_hidden_states=None, attention_mask=None, + temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches + + residual = hidden_states + + args = () if USE_PEFT_BACKEND else (scale,) + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states, *args) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, *args) + value = attn.to_v(encoder_hidden_states, *args) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + #################################################################### + # ARC GPUs can't allocate more than 4GB to a single block, Slice it: + batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] + hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) + do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type) + + if do_split: + for i in range(batch_size_attention // split_slice_size): + start_idx = i * split_slice_size + end_idx = (i + 1) * split_slice_size + if do_split_2: + for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name + start_idx_2 = i2 * split_2_slice_size + end_idx_2 = (i2 + 1) * split_2_slice_size + if do_split_3: + for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name + start_idx_3 = i3 * split_3_slice_size + end_idx_3 = (i3 + 1) * split_3_slice_size + + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice + del attn_slice + else: + query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2] + key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2] + attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2]) + + hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice + del attn_slice + else: + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None + + attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) + del query_slice + del key_slice + del attn_mask_slice + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + del attn_slice + torch.xpu.synchronize(query.device) + else: + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + #################################################################### + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, *args) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +def ipex_diffusers(): + #ARC GPUs can't allocate more than 4GB to a single block: + diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor + diffusers.models.attention_processor.AttnProcessor = AttnProcessor diff --git a/library/ipex/gradscaler.py b/library/ipex/gradscaler.py new file mode 100644 index 0000000000000000000000000000000000000000..6eb56bc2b821e8530557f517ebeaafa141b763a6 --- /dev/null +++ b/library/ipex/gradscaler.py @@ -0,0 +1,183 @@ +from collections import defaultdict +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import + +# pylint: disable=protected-access, missing-function-docstring, line-too-long + +device_supports_fp64 = torch.xpu.has_fp64_dtype() +OptState = ipex.cpu.autocast._grad_scaler.OptState +_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator +_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state + +def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument + per_device_inv_scale = _MultiDeviceReplicator(inv_scale) + per_device_found_inf = _MultiDeviceReplicator(found_inf) + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be hundreds of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] + # sync grad to master weight + if hasattr(optimizer, "sync_grad"): + optimizer.sync_grad() + with torch.no_grad(): + for group in optimizer.param_groups: + for param in group["params"]: + if param.grad is None: + continue + if (not allow_fp16) and param.grad.dtype == torch.float16: + raise ValueError("Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). + if param.grad.dtype is torch.float16: + param.grad = param.grad.coalesce() + to_unscale = param.grad._values() + else: + to_unscale = param.grad + + # -: is there a way to split by device and dtype without appending in the inner loop? + to_unscale = to_unscale.to("cpu") + per_device_and_dtype_grads[to_unscale.device][ + to_unscale.dtype + ].append(to_unscale) + + for _, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + core._amp_foreach_non_finite_check_and_unscale_( + grads, + per_device_found_inf.get("cpu"), + per_device_inv_scale.get("cpu"), + ) + + return per_device_found_inf._per_device_tensors + +def unscale_(self, optimizer): + """ + Divides ("unscales") the optimizer's gradient tensors by the scale factor. + :meth:`unscale_` is optional, serving cases where you need to + :ref:`modify or inspect gradients` + between the backward pass(es) and :meth:`step`. + If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. + Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: + ... + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + scaler.step(optimizer) + scaler.update() + Args: + optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. + .. warning:: + :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, + and only after all gradients for that optimizer's assigned parameters have been accumulated. + Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. + .. warning:: + :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. + """ + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise + raise RuntimeError( + "unscale_() has already been called on this optimizer since the last update()." + ) + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None + if device_supports_fp64: + inv_scale = self._scale.double().reciprocal().float() + else: + inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) + found_inf = torch.full( + (1,), 0.0, dtype=torch.float32, device=self._scale.device + ) + + optimizer_state["found_inf_per_device"] = self._unscale_grads_( + optimizer, inv_scale, found_inf, False + ) + optimizer_state["stage"] = OptState.UNSCALED + +def update(self, new_scale=None): + """ + Updates the scale factor. + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + Args: + new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor. + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) # type: ignore[union-attr] + else: + reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False." + assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined] + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) # type: ignore[union-attr] + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [ + found_inf.to(device="cpu", non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values() + ] + + assert len(found_infs) > 0, "No inf checks were recorded prior to update." + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + to_device = _scale.device + _scale = _scale.to("cpu") + _growth_tracker = _growth_tracker.to("cpu") + + core._amp_update_scale_( + _scale, + _growth_tracker, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval, + ) + + _scale = _scale.to(to_device) + _growth_tracker = _growth_tracker.to(to_device) + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + +def gradscaler_init(): + torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_ + torch.xpu.amp.GradScaler.unscale_ = unscale_ + torch.xpu.amp.GradScaler.update = update + return torch.xpu.amp.GradScaler diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py new file mode 100644 index 0000000000000000000000000000000000000000..d3cef8276ea60ef8b481633b081896d460affd9b --- /dev/null +++ b/library/ipex/hijacks.py @@ -0,0 +1,313 @@ +import os +from functools import wraps +from contextlib import nullcontext +import torch +import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import +import numpy as np + +device_supports_fp64 = torch.xpu.has_fp64_dtype() + +# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return + +class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods + def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument + if isinstance(device_ids, list) and len(device_ids) > 1: + print("IPEX backend doesn't support DataParallel on multiple XPU devices") + return module.to("xpu") + +def return_null_context(*args, **kwargs): # pylint: disable=unused-argument + return nullcontext() + +@property +def is_cuda(self): + return self.device.type == 'xpu' or self.device.type == 'cuda' + +def check_device(device): + return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) + +def return_xpu(device): + return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" + + +# Autocast +original_autocast_init = torch.amp.autocast_mode.autocast.__init__ +@wraps(torch.amp.autocast_mode.autocast.__init__) +def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None): + if device_type == "cuda": + return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) + else: + return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) + +# Latent Antialias CPU Offload: +original_interpolate = torch.nn.functional.interpolate +@wraps(torch.nn.functional.interpolate) +def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments + if antialias or align_corners is not None or mode == 'bicubic': + return_device = tensor.device + return_dtype = tensor.dtype + return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode, + align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype) + else: + return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode, + align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) + + +# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): +original_from_numpy = torch.from_numpy +@wraps(torch.from_numpy) +def from_numpy(ndarray): + if ndarray.dtype == float: + return original_from_numpy(ndarray.astype('float32')) + else: + return original_from_numpy(ndarray) + +original_as_tensor = torch.as_tensor +@wraps(torch.as_tensor) +def as_tensor(data, dtype=None, device=None): + if check_device(device): + device = return_xpu(device) + if isinstance(data, np.ndarray) and data.dtype == float and not ( + (isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)): + return original_as_tensor(data, dtype=torch.float32, device=device) + else: + return original_as_tensor(data, dtype=dtype, device=device) + + +if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None: + original_torch_bmm = torch.bmm + original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention +else: + # 32 bit attention workarounds for Alchemist: + try: + from .attention import torch_bmm_32_bit as original_torch_bmm + from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention + except Exception: # pylint: disable=broad-exception-caught + original_torch_bmm = torch.bmm + original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention + + +# Data Type Errors: +@wraps(torch.bmm) +def torch_bmm(input, mat2, *, out=None): + if input.dtype != mat2.dtype: + mat2 = mat2.to(input.dtype) + return original_torch_bmm(input, mat2, out=out) + +@wraps(torch.nn.functional.scaled_dot_product_attention) +def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): + if query.dtype != key.dtype: + key = key.to(dtype=query.dtype) + if query.dtype != value.dtype: + value = value.to(dtype=query.dtype) + if attn_mask is not None and query.dtype != attn_mask.dtype: + attn_mask = attn_mask.to(dtype=query.dtype) + return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) + +# A1111 FP16 +original_functional_group_norm = torch.nn.functional.group_norm +@wraps(torch.nn.functional.group_norm) +def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05): + if weight is not None and input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps) + +# A1111 BF16 +original_functional_layer_norm = torch.nn.functional.layer_norm +@wraps(torch.nn.functional.layer_norm) +def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05): + if weight is not None and input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps) + +# Training +original_functional_linear = torch.nn.functional.linear +@wraps(torch.nn.functional.linear) +def functional_linear(input, weight, bias=None): + if input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_linear(input, weight, bias=bias) + +original_functional_conv2d = torch.nn.functional.conv2d +@wraps(torch.nn.functional.conv2d) +def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + if input.dtype != weight.data.dtype: + input = input.to(dtype=weight.data.dtype) + if bias is not None and bias.data.dtype != weight.data.dtype: + bias.data = bias.data.to(dtype=weight.data.dtype) + return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + +# A1111 Embedding BF16 +original_torch_cat = torch.cat +@wraps(torch.cat) +def torch_cat(tensor, *args, **kwargs): + if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): + return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) + else: + return original_torch_cat(tensor, *args, **kwargs) + +# SwinIR BF16: +original_functional_pad = torch.nn.functional.pad +@wraps(torch.nn.functional.pad) +def functional_pad(input, pad, mode='constant', value=None): + if mode == 'reflect' and input.dtype == torch.bfloat16: + return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16) + else: + return original_functional_pad(input, pad, mode=mode, value=value) + + +original_torch_tensor = torch.tensor +@wraps(torch.tensor) +def torch_tensor(data, *args, dtype=None, device=None, **kwargs): + if check_device(device): + device = return_xpu(device) + if not device_supports_fp64: + if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device): + if dtype == torch.float64: + dtype = torch.float32 + elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)): + dtype = torch.float32 + return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs) + +original_Tensor_to = torch.Tensor.to +@wraps(torch.Tensor.to) +def Tensor_to(self, device=None, *args, **kwargs): + if check_device(device): + return original_Tensor_to(self, return_xpu(device), *args, **kwargs) + else: + return original_Tensor_to(self, device, *args, **kwargs) + +original_Tensor_cuda = torch.Tensor.cuda +@wraps(torch.Tensor.cuda) +def Tensor_cuda(self, device=None, *args, **kwargs): + if check_device(device): + return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs) + else: + return original_Tensor_cuda(self, device, *args, **kwargs) + +original_Tensor_pin_memory = torch.Tensor.pin_memory +@wraps(torch.Tensor.pin_memory) +def Tensor_pin_memory(self, device=None, *args, **kwargs): + if device is None: + device = "xpu" + if check_device(device): + return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs) + else: + return original_Tensor_pin_memory(self, device, *args, **kwargs) + +original_UntypedStorage_init = torch.UntypedStorage.__init__ +@wraps(torch.UntypedStorage.__init__) +def UntypedStorage_init(*args, device=None, **kwargs): + if check_device(device): + return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs) + else: + return original_UntypedStorage_init(*args, device=device, **kwargs) + +original_UntypedStorage_cuda = torch.UntypedStorage.cuda +@wraps(torch.UntypedStorage.cuda) +def UntypedStorage_cuda(self, device=None, *args, **kwargs): + if check_device(device): + return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs) + else: + return original_UntypedStorage_cuda(self, device, *args, **kwargs) + +original_torch_empty = torch.empty +@wraps(torch.empty) +def torch_empty(*args, device=None, **kwargs): + if check_device(device): + return original_torch_empty(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_empty(*args, device=device, **kwargs) + +original_torch_randn = torch.randn +@wraps(torch.randn) +def torch_randn(*args, device=None, dtype=None, **kwargs): + if dtype == bytes: + dtype = None + if check_device(device): + return original_torch_randn(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_randn(*args, device=device, **kwargs) + +original_torch_ones = torch.ones +@wraps(torch.ones) +def torch_ones(*args, device=None, **kwargs): + if check_device(device): + return original_torch_ones(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_ones(*args, device=device, **kwargs) + +original_torch_zeros = torch.zeros +@wraps(torch.zeros) +def torch_zeros(*args, device=None, **kwargs): + if check_device(device): + return original_torch_zeros(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_zeros(*args, device=device, **kwargs) + +original_torch_linspace = torch.linspace +@wraps(torch.linspace) +def torch_linspace(*args, device=None, **kwargs): + if check_device(device): + return original_torch_linspace(*args, device=return_xpu(device), **kwargs) + else: + return original_torch_linspace(*args, device=device, **kwargs) + +original_torch_Generator = torch.Generator +@wraps(torch.Generator) +def torch_Generator(device=None): + if check_device(device): + return original_torch_Generator(return_xpu(device)) + else: + return original_torch_Generator(device) + +original_torch_load = torch.load +@wraps(torch.load) +def torch_load(f, map_location=None, *args, **kwargs): + if map_location is None: + map_location = "xpu" + if check_device(map_location): + return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs) + else: + return original_torch_load(f, *args, map_location=map_location, **kwargs) + + +# Hijack Functions: +def ipex_hijacks(): + torch.tensor = torch_tensor + torch.Tensor.to = Tensor_to + torch.Tensor.cuda = Tensor_cuda + torch.Tensor.pin_memory = Tensor_pin_memory + torch.UntypedStorage.__init__ = UntypedStorage_init + torch.UntypedStorage.cuda = UntypedStorage_cuda + torch.empty = torch_empty + torch.randn = torch_randn + torch.ones = torch_ones + torch.zeros = torch_zeros + torch.linspace = torch_linspace + torch.Generator = torch_Generator + torch.load = torch_load + + torch.backends.cuda.sdp_kernel = return_null_context + torch.nn.DataParallel = DummyDataParallel + torch.UntypedStorage.is_cuda = is_cuda + torch.amp.autocast_mode.autocast.__init__ = autocast_init + + torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention + torch.nn.functional.group_norm = functional_group_norm + torch.nn.functional.layer_norm = functional_layer_norm + torch.nn.functional.linear = functional_linear + torch.nn.functional.conv2d = functional_conv2d + torch.nn.functional.interpolate = interpolate + torch.nn.functional.pad = functional_pad + + torch.bmm = torch_bmm + torch.cat = torch_cat + if not device_supports_fp64: + torch.from_numpy = from_numpy + torch.as_tensor = as_tensor diff --git a/library/model_util.py b/library/model_util.py new file mode 100644 index 0000000000000000000000000000000000000000..7d6070339a9ef65df07069676543ec123945aca9 --- /dev/null +++ b/library/model_util.py @@ -0,0 +1,1356 @@ +# v1: split from train_db_fixed.py. +# v2: support safetensors + +import math +import os + +import torch +from .device_utils import init_ipex +init_ipex() + +import diffusers +from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging +from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel +from safetensors.torch import load_file, save_file +from .original_unet import UNet2DConditionModel +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) + +# DiffUsers版StableDiffusionのモデルパラメータ +NUM_TRAIN_TIMESTEPS = 1000 +BETA_START = 0.00085 +BETA_END = 0.0120 + +UNET_PARAMS_MODEL_CHANNELS = 320 +UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4] +UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1] +UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32` +UNET_PARAMS_IN_CHANNELS = 4 +UNET_PARAMS_OUT_CHANNELS = 4 +UNET_PARAMS_NUM_RES_BLOCKS = 2 +UNET_PARAMS_CONTEXT_DIM = 768 +UNET_PARAMS_NUM_HEADS = 8 +# UNET_PARAMS_USE_LINEAR_PROJECTION = False + +VAE_PARAMS_Z_CHANNELS = 4 +VAE_PARAMS_RESOLUTION = 256 +VAE_PARAMS_IN_CHANNELS = 3 +VAE_PARAMS_OUT_CH = 3 +VAE_PARAMS_CH = 128 +VAE_PARAMS_CH_MULT = [1, 2, 4, 4] +VAE_PARAMS_NUM_RES_BLOCKS = 2 + +# V2 +V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] +V2_UNET_PARAMS_CONTEXT_DIM = 1024 +# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True + +# Diffusersの設定を読み込むための参照モデル +DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5" +DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1" + + +# region StableDiffusion->Diffusersの変換コード +# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0) + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + if diffusers.__version__ < "0.17.0": + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") + + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") + + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + else: + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + reshaping = False + if diffusers.__version__ < "0.17.0": + if "proj_attn.weight" in new_path: + reshaping = True + else: + if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2: + reshaping = True + + if reshaping: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def linear_transformer_to_conv(checkpoint): + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim == 2: + checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) + + +def convert_ldm_unet_checkpoint(v2, checkpoint, config): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + unet_key = "model.diffusion_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + # オリジナル: + # if ["conv.weight", "conv.bias"] in output_block_list.values(): + # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) + + # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが + for l in output_block_list.values(): + l.sort() + + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + # SDのv2では1*1のconv2dがlinearに変わっている + # 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要 + if v2 and not config.get("use_linear_projection", False): + linear_transformer_to_conv(new_checkpoint) + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + # if len(vae_state_dict) == 0: + # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict + # vae_state_dict = checkpoint + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)} + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)} + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + # unet_params = original_config.model.params.unet_config.params + + block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + config = dict( + sample_size=UNET_PARAMS_IMAGE_SIZE, + in_channels=UNET_PARAMS_IN_CHANNELS, + out_channels=UNET_PARAMS_OUT_CHANNELS, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, + cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, + attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, + # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION, + ) + if v2 and use_linear_projection_in_v2: + config["use_linear_projection"] = True + + return config + + +def create_vae_diffusers_config(): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + # vae_params = original_config.model.params.first_stage_config.params.ddconfig + # _ = original_config.model.params.first_stage_config.params.embed_dim + block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = dict( + sample_size=VAE_PARAMS_RESOLUTION, + in_channels=VAE_PARAMS_IN_CHANNELS, + out_channels=VAE_PARAMS_OUT_CH, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + latent_channels=VAE_PARAMS_Z_CHANNELS, + layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, + ) + return config + + +def convert_ldm_clip_checkpoint_v1(checkpoint): + keys = list(checkpoint.keys()) + text_model_dict = {} + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + # remove position_ids for newer transformer, which causes error :( + if "text_model.embeddings.position_ids" in text_model_dict: + text_model_dict.pop("text_model.embeddings.position_ids") + + return text_model_dict + + +def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): + # 嫌になるくらい違うぞ! + def convert_key(key): + if not key.startswith("cond_stage_model"): + return None + + # common conversion + key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") + key = key.replace("cond_stage_model.model.", "text_model.") + + if "resblocks" in key: + # resblocks conversion + key = key.replace(".resblocks.", ".layers.") + if ".ln_" in key: + key = key.replace(".ln_", ".layer_norm") + elif ".mlp." in key: + key = key.replace(".c_fc.", ".fc1.") + key = key.replace(".c_proj.", ".fc2.") + elif ".attn.out_proj" in key: + key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") + elif ".attn.in_proj" in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in SD: {key}") + elif ".positional_embedding" in key: + key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") + elif ".text_projection" in key: + key = None # 使われない??? + elif ".logit_scale" in key: + key = None # 使われない??? + elif ".token_embedding" in key: + key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") + elif ".ln_final" in key: + key = key.replace(".ln_final", ".final_layer_norm") + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + # remove resblocks 23 + if ".resblocks.23." in key: + continue + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if ".resblocks.23." in key: + continue + if ".resblocks" in key and ".attn.in_proj_" in key: + # 三つに分割 + values = torch.chunk(checkpoint[key], 3) + + key_suffix = ".weight" if "weight" in key else ".bias" + key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") + key_pfx = key_pfx.replace("_weight", "") + key_pfx = key_pfx.replace("_bias", "") + key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") + new_sd[key_pfx + "q_proj" + key_suffix] = values[0] + new_sd[key_pfx + "k_proj" + key_suffix] = values[1] + new_sd[key_pfx + "v_proj" + key_suffix] = values[2] + + # rename or add position_ids + ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" + if ANOTHER_POSITION_IDS_KEY in new_sd: + # waifu diffusion v1.4 + position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] + del new_sd[ANOTHER_POSITION_IDS_KEY] + else: + position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) + + new_sd["text_model.embeddings.position_ids"] = position_ids + return new_sd + + +# endregion + + +# region Diffusers->StableDiffusion の変換コード +# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0) + + +def conv_transformer_to_linear(checkpoint): + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + + +def convert_unet_state_dict_to_sd(v2, unet_state_dict): + unet_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), + ] + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), + ] + + unet_conversion_map_layer = [] + for i in range(4): + # loop over downblocks/upblocks + + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. + mapping = {k: k for k in unet_state_dict.keys()} + for sd_name, hf_name in unet_conversion_map: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} + + if v2: + conv_transformer_to_linear(new_state_dict) + + return new_state_dict + + +def controlnet_conversion_map(): + unet_conversion_map = [ + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("middle_block_out.0.weight", "controlnet_mid_block.weight"), + ("middle_block_out.0.bias", "controlnet_mid_block.bias"), + ] + + unet_conversion_map_resnet = [ + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), + ] + + unet_conversion_map_layer = [] + for i in range(4): + for j in range(2): + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"] + for i, hf_prefix in enumerate(controlnet_cond_embedding_names): + hf_prefix = f"controlnet_cond_embedding.{hf_prefix}." + sd_prefix = f"input_hint_block.{i*2}." + unet_conversion_map_layer.append((sd_prefix, hf_prefix)) + + for i in range(12): + hf_prefix = f"controlnet_down_blocks.{i}." + sd_prefix = f"zero_convs.{i}.0." + unet_conversion_map_layer.append((sd_prefix, hf_prefix)) + + return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer + + +def convert_controlnet_state_dict_to_sd(controlnet_state_dict): + unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map() + + mapping = {k: k for k in controlnet_state_dict.keys()} + for sd_name, diffusers_name in unet_conversion_map: + mapping[diffusers_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, diffusers_part in unet_conversion_map_resnet: + v = v.replace(diffusers_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, diffusers_part in unet_conversion_map_layer: + v = v.replace(diffusers_part, sd_part) + mapping[k] = v + new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()} + return new_state_dict + + +def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict): + unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map() + + mapping = {k: k for k in controlnet_state_dict.keys()} + for sd_name, diffusers_name in unet_conversion_map: + mapping[sd_name] = diffusers_name + for k, v in mapping.items(): + for sd_part, diffusers_part in unet_conversion_map_layer: + v = v.replace(sd_part, diffusers_part) + mapping[k] = v + for k, v in mapping.items(): + if "resnets" in v: + for sd_part, diffusers_part in unet_conversion_map_resnet: + v = v.replace(sd_part, diffusers_part) + mapping[k] = v + new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()} + return new_state_dict + + +# ================# +# VAE Conversion # +# ================# + + +def reshape_weight_for_sd(w): + # convert HF linear weights to SD conv2d weights + return w.reshape(*w.shape, 1, 1) + + +def convert_vae_state_dict(vae_state_dict): + vae_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("nin_shortcut", "conv_shortcut"), + ("norm_out", "conv_norm_out"), + ("mid.attn_1.", "mid_block.attentions.0."), + ] + + for i in range(4): + # down_blocks have two resnets + for j in range(2): + hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." + sd_down_prefix = f"encoder.down.{i}.block.{j}." + vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) + + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." + sd_downsample_prefix = f"down.{i}.downsample." + vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) + + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"up.{3-i}.upsample." + vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) + + # up_blocks have three resnets + # also, up blocks in hf are numbered in reverse from sd + for j in range(3): + hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." + sd_up_prefix = f"decoder.up.{3-i}.block.{j}." + vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) + + # this part accounts for mid blocks in both the encoder and the decoder + for i in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{i}." + sd_mid_res_prefix = f"mid.block_{i+1}." + vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + if diffusers.__version__ < "0.17.0": + vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "query."), + ("k.", "key."), + ("v.", "value."), + ("proj_out.", "proj_attn."), + ] + else: + vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "to_q."), + ("k.", "to_k."), + ("v.", "to_v."), + ("proj_out.", "to_out.0."), + ] + + mapping = {k: k for k in vae_state_dict.keys()} + for k, v in mapping.items(): + for sd_part, hf_part in vae_conversion_map: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + if "attentions" in k: + for sd_part, hf_part in vae_conversion_map_attn: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} + weights_to_convert = ["q", "k", "v", "proj_out"] + for k, v in new_state_dict.items(): + for weight_name in weights_to_convert: + if f"mid.attn_1.{weight_name}.weight" in k: + # logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1") + new_state_dict[k] = reshape_weight_for_sd(v) + + return new_state_dict + + +# endregion + +# region 自作のモデル読み書きなど + + +def is_safetensors(path): + return os.path.splitext(path)[1].lower() == ".safetensors" + + +def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): + # text encoderの格納形式が違うモデルに対応する ('text_model'がない) + TEXT_ENCODER_KEY_REPLACEMENTS = [ + ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."), + ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."), + ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."), + ] + + if is_safetensors(ckpt_path): + checkpoint = None + state_dict = load_file(ckpt_path) # , device) # may causes error + else: + checkpoint = torch.load(ckpt_path, map_location=device) + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + checkpoint = None + + key_reps = [] + for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: + for key in state_dict.keys(): + if key.startswith(rep_from): + new_key = rep_to + key[len(rep_from) :] + key_reps.append((key, new_key)) + + for key, new_key in key_reps: + state_dict[new_key] = state_dict[key] + del state_dict[key] + + return checkpoint, state_dict + + +# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 +def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True): + _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device) + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2) + converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) + + unet = UNet2DConditionModel(**unet_config).to(device) + info = unet.load_state_dict(converted_unet_checkpoint) + logger.info(f"loading u-net: {info}") + + # Convert the VAE model. + vae_config = create_vae_diffusers_config() + converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) + + vae = AutoencoderKL(**vae_config).to(device) + info = vae.load_state_dict(converted_vae_checkpoint) + logger.info(f"loading vae: {info}") + + # convert text_model + if v2: + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) + cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=23, + num_attention_heads=16, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=512, + torch_dtype="float32", + transformers_version="4.25.0.dev0", + ) + text_model = CLIPTextModel._from_config(cfg) + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + else: + converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) + + # logging.set_verbosity_error() # don't show annoying warning + # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device) + # logging.set_verbosity_warning() + # logger.info(f"config: {text_model.config}") + cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=77, + hidden_act="quick_gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=768, + torch_dtype="float32", + ) + text_model = CLIPTextModel._from_config(cfg) + info = text_model.load_state_dict(converted_text_encoder_checkpoint) + logger.info(f"loading text encoder: {info}") + + return text_model, vae, unet + + +def get_model_version_str_for_sd1_sd2(v2, v_parameterization): + # only for reference + version_str = "sd" + if v2: + version_str += "_v2" + else: + version_str += "_v1" + if v_parameterization: + version_str += "_v" + return version_str + + +def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): + def convert_key(key): + # position_idsの除去 + if ".position_ids" in key: + return None + + # common + key = key.replace("text_model.encoder.", "transformer.") + key = key.replace("text_model.", "") + if "layers" in key: + # resblocks conversion + key = key.replace(".layers.", ".resblocks.") + if ".layer_norm" in key: + key = key.replace(".layer_norm", ".ln_") + elif ".mlp." in key: + key = key.replace(".fc1.", ".c_fc.") + key = key.replace(".fc2.", ".c_proj.") + elif ".self_attn.out_proj" in key: + key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") + elif ".self_attn." in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in DiffUsers model: {key}") + elif ".position_embedding" in key: + key = key.replace("embeddings.position_embedding.weight", "positional_embedding") + elif ".token_embedding" in key: + key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") + elif "final_layer_norm" in key: + key = key.replace("final_layer_norm", "ln_final") + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if "layers" in key and "q_proj" in key: + # 三つを結合 + key_q = key + key_k = key.replace("q_proj", "k_proj") + key_v = key.replace("q_proj", "v_proj") + + value_q = checkpoint[key_q] + value_k = checkpoint[key_k] + value_v = checkpoint[key_v] + value = torch.cat([value_q, value_k, value_v]) + + new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") + new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") + new_sd[new_key] = value + + # 最後の層などを捏造するか + if make_dummy_weights: + logger.info("make dummy weights for resblock.23, text_projection and logit scale.") + keys = list(new_sd.keys()) + for key in keys: + if key.startswith("transformer.resblocks.22."): + new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる + + # Diffusersに含まれない重みを作っておく + new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) + new_sd["logit_scale"] = torch.tensor(1) + + return new_sd + + +def save_stable_diffusion_checkpoint( + v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None +): + if ckpt_path is not None: + # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む + checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) + if checkpoint is None: # safetensors または state_dictのckpt + checkpoint = {} + strict = False + else: + strict = True + if "state_dict" in state_dict: + del state_dict["state_dict"] + else: + # 新しく作る + assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint" + checkpoint = {} + state_dict = {} + strict = False + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + assert not strict or key in state_dict, f"Illegal key in save SD: {key}" + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + # Convert the UNet model + unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) + update_sd("model.diffusion_model.", unet_state_dict) + + # Convert the text encoder model + if v2: + make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる + text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) + update_sd("cond_stage_model.model.", text_enc_dict) + else: + text_enc_dict = text_encoder.state_dict() + update_sd("cond_stage_model.transformer.", text_enc_dict) + + # Convert the VAE + if vae is not None: + vae_dict = convert_vae_state_dict(vae.state_dict()) + update_sd("first_stage_model.", vae_dict) + + # Put together new checkpoint + key_count = len(state_dict.keys()) + new_ckpt = {"state_dict": state_dict} + + # epoch and global_step are sometimes not int + try: + if "epoch" in checkpoint: + epochs += checkpoint["epoch"] + if "global_step" in checkpoint: + steps += checkpoint["global_step"] + except: + pass + + new_ckpt["epoch"] = epochs + new_ckpt["global_step"] = steps + + if is_safetensors(output_file): + # TODO Tensor以外のdictの値を削除したほうがいいか + save_file(state_dict, output_file, metadata) + else: + torch.save(new_ckpt, output_file) + + return key_count + + +def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False): + if pretrained_model_name_or_path is None: + # load default settings for v1/v2 + if v2: + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2 + else: + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1 + + scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") + if vae is None: + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") + + # original U-Net cannot be saved, so we need to convert it to the Diffusers version + # TODO this consumes a lot of memory + diffusers_unet = diffusers.UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet") + diffusers_unet.load_state_dict(unet.state_dict()) + + pipeline = StableDiffusionPipeline( + unet=diffusers_unet, + text_encoder=text_encoder, + vae=vae, + scheduler=scheduler, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=None, + ) + pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) + + +VAE_PREFIX = "first_stage_model." + + +def load_vae(vae_id, dtype): + logger.info(f"load VAE: {vae_id}") + if os.path.isdir(vae_id) or not os.path.isfile(vae_id): + # Diffusers local/remote + try: + vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) + except EnvironmentError as e: + logger.error(f"exception occurs in loading vae: {e}") + logger.error("retry with subfolder='vae'") + vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) + return vae + + # local + vae_config = create_vae_diffusers_config() + + if vae_id.endswith(".bin"): + # SD 1.5 VAE on Huggingface + converted_vae_checkpoint = torch.load(vae_id, map_location="cpu") + else: + # StableDiffusion + vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu") + vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model + + # vae only or full model + full_model = False + for vae_key in vae_sd: + if vae_key.startswith(VAE_PREFIX): + full_model = True + break + if not full_model: + sd = {} + for key, value in vae_sd.items(): + sd[VAE_PREFIX + key] = value + vae_sd = sd + del sd + + # Convert the VAE model. + converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + return vae + + +# endregion + + +def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64): + max_width, max_height = max_reso + max_area = max_width * max_height + + resos = set() + + width = int(math.sqrt(max_area) // divisible) * divisible + resos.add((width, width)) + + width = min_size + while width <= max_size: + height = min(max_size, int((max_area // width) // divisible) * divisible) + if height >= min_size: + resos.add((width, height)) + resos.add((height, width)) + + # # make additional resos + # if width >= height and width - divisible >= min_size: + # resos.add((width - divisible, height)) + # resos.add((height, width - divisible)) + # if height >= width and height - divisible >= min_size: + # resos.add((width, height - divisible)) + # resos.add((height - divisible, width)) + + width += divisible + + resos = list(resos) + resos.sort() + return resos + + +# if __name__ == "__main__": +# resos = make_bucket_resolutions((512, 768)) +# logger.info(f"{len(resos)}") +# logger.info(f"{resos}") +# aspect_ratios = [w / h for w, h in resos] +# logger.info(f"{aspect_ratios}") + +# ars = set() +# for ar in aspect_ratios: +# if ar in ars: +# logger.error(f"error! duplicate ar: {ar}") +# ars.add(ar) diff --git a/library/original_unet.py b/library/original_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..89d16ed5f9ecc15e2b15d1b3e92ebb9ed93e39f6 --- /dev/null +++ b/library/original_unet.py @@ -0,0 +1,1919 @@ +# Diffusers 0.10.2からStable Diffusionに必要な部分だけを持ってくる +# 条件分岐等で不要な部分は削除している +# コードの多くはDiffusersからコピーしている +# 制約として、モデルのstate_dictがDiffusers 0.10.2のものと同じ形式である必要がある + +# Copy from Diffusers 0.10.2 for Stable Diffusion. Most of the code is copied from Diffusers. +# Unnecessary parts are deleted by condition branching. +# As a constraint, the state_dict of the model must be in the same format as that of Diffusers 0.10.2 + +""" +v1.5とv2.1の相違点は +- attention_head_dimがintかlist[int]か +- cross_attention_dimが768か1024か +- use_linear_projection: trueがない(=False, 1.5)かあるか +- upcast_attentionがFalse(1.5)かTrue(2.1)か +- (以下は多分無視していい) +- sample_sizeが64か96か +- dual_cross_attentionがあるかないか +- num_class_embedsがあるかないか +- only_cross_attentionがあるかないか + +v1.5 +{ + "_class_name": "UNet2DConditionModel", + "_diffusers_version": "0.6.0", + "act_fn": "silu", + "attention_head_dim": 8, + "block_out_channels": [ + 320, + 640, + 1280, + 1280 + ], + "center_input_sample": false, + "cross_attention_dim": 768, + "down_block_types": [ + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D" + ], + "downsample_padding": 1, + "flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "out_channels": 4, + "sample_size": 64, + "up_block_types": [ + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D" + ] +} + +v2.1 +{ + "_class_name": "UNet2DConditionModel", + "_diffusers_version": "0.10.0.dev0", + "act_fn": "silu", + "attention_head_dim": [ + 5, + 10, + 20, + 20 + ], + "block_out_channels": [ + 320, + 640, + 1280, + 1280 + ], + "center_input_sample": false, + "cross_attention_dim": 1024, + "down_block_types": [ + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D" + ], + "downsample_padding": 1, + "dual_cross_attention": false, + "flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_class_embeds": null, + "only_cross_attention": false, + "out_channels": 4, + "sample_size": 96, + "up_block_types": [ + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D" + ], + "use_linear_projection": true, + "upcast_attention": true +} +""" + +import math +from types import SimpleNamespace +from typing import Dict, Optional, Tuple, Union +import torch +from torch import nn +from torch.nn import functional as F +from einops import rearrange +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) + +BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280) +TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0] +TIME_EMBED_DIM = BLOCK_OUT_CHANNELS[0] * 4 +IN_CHANNELS: int = 4 +OUT_CHANNELS: int = 4 +LAYERS_PER_BLOCK: int = 2 +LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1 +TIME_EMBED_FLIP_SIN_TO_COS: bool = True +TIME_EMBED_FREQ_SHIFT: int = 0 +NORM_GROUPS: int = 32 +NORM_EPS: float = 1e-5 +TRANSFORMER_NORM_NUM_GROUPS = 32 + +DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"] +UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"] + + +# region memory efficient attention + +# FlashAttentionを使うCrossAttention +# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py +# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE + +# constants + +EPSILON = 1e-6 + +# helper functions + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +# flash attention forwards and backwards + +# https://arxiv.org/abs/2205.14135 + + +class FlashAttentionFunction(torch.autograd.Function): + @staticmethod + @torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) + + scale = q.shape[-1] ** -0.5 + + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.0) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.0) + + p = exp_attn_weights / lc + + dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc) + dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +# endregion + + +def get_parameter_dtype(parameter: torch.nn.Module): + return next(parameter.parameters()).dtype + + +def get_parameter_device(parameter: torch.nn.Module): + return next(parameter.parameters()).device + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +# Deep Shrink: We do not common this function, because minimize dependencies. +def resize_like(x, target, mode="bicubic", align_corners=False): + org_dtype = x.dtype + if org_dtype == torch.bfloat16: + x = x.to(torch.float32) + + if x.shape[-2:] != target.shape[-2:]: + if mode == "nearest": + x = F.interpolate(x, size=target.shape[-2:], mode=mode) + else: + x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners) + + if org_dtype == torch.bfloat16: + x = x.to(org_dtype) + return x + + +class SampleOutput: + def __init__(self, sample): + self.sample = sample + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + self.act = None + if act_fn == "silu": + self.act = nn.SiLU() + elif act_fn == "mish": + self.act = nn.Mish() + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + + def forward(self, sample): + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class ResnetBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.norm1 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=in_channels, eps=NORM_EPS, affine=True) + + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels) + + self.norm2 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=out_channels, eps=NORM_EPS, affine=True) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + # if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + + self.use_in_shortcut = self.in_channels != self.out_channels + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + add_downsample=True, + ): + super().__init__() + + self.has_cross_attention = False + resnets = [] + + for i in range(LAYERS_PER_BLOCK): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + ) + ) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = [Downsample2D(out_channels, out_channels=out_channels)] + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + pass + + def set_use_sdpa(self, sdpa): + pass + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class Downsample2D(nn.Module): + def __init__(self, channels, out_channels): + super().__init__() + + self.channels = channels + self.out_channels = out_channels + + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1) + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class CrossAttention(nn.Module): + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + upcast_attention: bool = False, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + # no dropout here + + self.use_memory_efficient_attention_xformers = False + self.use_memory_efficient_attention_mem_eff = False + self.use_sdpa = False + + # Attention processor + self.processor = None + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + self.use_memory_efficient_attention_xformers = xformers + self.use_memory_efficient_attention_mem_eff = mem_eff + + def set_use_sdpa(self, sdpa): + self.use_sdpa = sdpa + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def set_processor(self): + return self.processor + + def get_processor(self): + return self.processor + + def forward(self, hidden_states, context=None, mask=None, **kwargs): + if self.processor is not None: + ( + hidden_states, + encoder_hidden_states, + attention_mask, + ) = translate_attention_names_from_diffusers( + hidden_states=hidden_states, context=context, mask=mask, **kwargs + ) + return self.processor( + attn=self, + hidden_states=hidden_states, + encoder_hidden_states=context, + attention_mask=mask, + **kwargs + ) + if self.use_memory_efficient_attention_xformers: + return self.forward_memory_efficient_xformers(hidden_states, context, mask) + if self.use_memory_efficient_attention_mem_eff: + return self.forward_memory_efficient_mem_eff(hidden_states, context, mask) + if self.use_sdpa: + return self.forward_sdpa(hidden_states, context, mask) + + query = self.to_q(hidden_states) + context = context if context is not None else hidden_states + key = self.to_k(context) + value = self.to_v(context) + + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + hidden_states = self._attention(query, key, value) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # hidden_states = self.to_out[1](hidden_states) # no dropout + return hidden_states + + def _attention(self, query, key, value): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + # TODO support Hypernetworks + def forward_memory_efficient_xformers(self, x, context=None, mask=None): + import xformers.ops + + h = self.heads + q_in = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k_in = self.to_k(context) + v_in = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + + out = rearrange(out, "b n h d -> b n (h d)", h=h) + + out = self.to_out[0](out) + return out + + def forward_memory_efficient_mem_eff(self, x, context=None, mask=None): + flash_func = FlashAttentionFunction + + q_bucket_size = 512 + k_bucket_size = 1024 + + h = self.heads + q = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + out = self.to_out[0](out) + return out + + def forward_sdpa(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k_in = self.to_k(context) + v_in = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + + out = rearrange(out, "b h n d -> b n (h d)", h=h) + + out = self.to_out[0](out) + return out + +def translate_attention_names_from_diffusers( + hidden_states: torch.FloatTensor, + context: Optional[torch.FloatTensor] = None, + mask: Optional[torch.FloatTensor] = None, + # HF naming + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None +): + # translate from hugging face diffusers + context = context if context is not None else encoder_hidden_states + + # translate from hugging face diffusers + mask = mask if mask is not None else attention_mask + + return hidden_states, context, mask + +# feedforward +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + ): + super().__init__() + inner_dim = int(dim * 4) # mult is always 4 + + self.net = nn.ModuleList([]) + # project in + self.net.append(GEGLU(dim, inner_dim)) + # project dropout + self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0 + # project out + self.net.append(nn.Linear(inner_dim, dim)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False + ): + super().__init__() + + # 1. Self-Attn + self.attn1 = CrossAttention( + query_dim=dim, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + upcast_attention=upcast_attention, + ) + self.ff = FeedForward(dim) + + # 2. Cross-Attn + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + upcast_attention=upcast_attention, + ) + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool): + self.attn1.set_use_memory_efficient_attention(xformers, mem_eff) + self.attn2.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa: bool): + self.attn1.set_use_sdpa(sdpa) + self.attn2.set_use_sdpa(sdpa) + + def forward(self, hidden_states, context=None, timestep=None): + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + hidden_states = self.attn1(norm_hidden_states) + hidden_states + + # 2. Cross-Attention + norm_hidden_states = self.norm2(hidden_states) + hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states + + # 3. Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + return hidden_states + + +class Transformer2DModel(nn.Module): + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + use_linear_projection: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.use_linear_projection = use_linear_projection + + self.norm = torch.nn.GroupNorm(num_groups=TRANSFORMER_NORM_NUM_GROUPS, num_channels=in_channels, eps=1e-6, affine=True) + + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + ] + ) + + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + for transformer in self.transformer_blocks: + transformer.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa): + for transformer in self.transformer_blocks: + transformer.set_use_sdpa(sdpa) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + # 1. Input + batch, _, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep) + + # 3. Output + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return SampleOutput(sample=output) + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + add_downsample=True, + cross_attention_dim=1280, + attn_num_head_channels=1, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + self.has_cross_attention = True + resnets = [] + attentions = [] + + self.attn_num_head_channels = attn_num_head_channels + + for i in range(LAYERS_PER_BLOCK): + in_channels = in_channels if i == 0 else out_channels + + resnets.append(ResnetBlock2D(in_channels=in_channels, out_channels=out_channels)) + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + for attn in self.attentions: + attn.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa): + for attn in self.attentions: + attn.set_use_sdpa(sdpa) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + attn_num_head_channels=1, + cross_attention_dim=1280, + use_linear_projection=False, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + # Middle block has two resnets and one attention + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + ), + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + ), + ] + attentions = [ + Transformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + ) + ] + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + for attn in self.attentions: + attn.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa): + for attn in self.attentions: + attn.set_use_sdpa(sdpa) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + for i, resnet in enumerate(self.resnets): + attn = None if i == 0 else self.attentions[i - 1] + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + if attn is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + if attn is not None: + hidden_states = attn(hidden_states, encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class Upsample2D(nn.Module): + def __init__(self, channels, out_channels): + super().__init__() + self.channels = channels + self.out_channels = out_channels + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, hidden_states, output_size): + assert hidden_states.shape[1] == self.channels + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + add_upsample=True, + ): + super().__init__() + + self.has_cross_attention = False + resnets = [] + + for i in range(LAYERS_PER_BLOCK_UP): + res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + pass + + def set_use_sdpa(self, sdpa): + pass + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + attn_num_head_channels=1, + cross_attention_dim=1280, + add_upsample=True, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(LAYERS_PER_BLOCK_UP): + res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + ) + ) + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + for attn in self.attentions: + attn.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa): + for attn in self.attentions: + attn.set_use_sdpa(sdpa) + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +def get_down_block( + down_block_type, + in_channels, + out_channels, + add_downsample, + attn_num_head_channels, + cross_attention_dim, + use_linear_projection, + upcast_attention, +): + if down_block_type == "DownBlock2D": + return DownBlock2D( + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + ) + elif down_block_type == "CrossAttnDownBlock2D": + return CrossAttnDownBlock2D( + in_channels=in_channels, + out_channels=out_channels, + add_downsample=add_downsample, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + +def get_up_block( + up_block_type, + in_channels, + out_channels, + prev_output_channel, + add_upsample, + attn_num_head_channels, + cross_attention_dim=None, + use_linear_projection=False, + upcast_attention=False, +): + if up_block_type == "UpBlock2D": + return UpBlock2D( + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + add_upsample=add_upsample, + ) + elif up_block_type == "CrossAttnUpBlock2D": + return CrossAttnUpBlock2D( + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + attn_num_head_channels=attn_num_head_channels, + cross_attention_dim=cross_attention_dim, + add_upsample=add_upsample, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + +class UNet2DConditionModel(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + sample_size: Optional[int] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + cross_attention_dim: int = 1280, + use_linear_projection: bool = False, + upcast_attention: bool = False, + **kwargs, + ): + super().__init__() + assert sample_size is not None, "sample_size must be specified" + logger.info( + f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}" + ) + + # 外部からの参照用に定義しておく + self.in_channels = IN_CHANNELS + self.out_channels = OUT_CHANNELS + + self.sample_size = sample_size + self.prepare_config(sample_size=sample_size) + + # state_dictの書式が変わるのでmoduleの持ち方は変えられない + + # input + self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(BLOCK_OUT_CHANNELS[0], TIME_EMBED_FLIP_SIN_TO_COS, TIME_EMBED_FREQ_SHIFT) + + self.time_embedding = TimestepEmbedding(TIMESTEP_INPUT_DIM, TIME_EMBED_DIM) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * 4 + + # down + output_channel = BLOCK_OUT_CHANNELS[0] + for i, down_block_type in enumerate(DOWN_BLOCK_TYPES): + input_channel = output_channel + output_channel = BLOCK_OUT_CHANNELS[i] + is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1 + + down_block = get_down_block( + down_block_type, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + attn_num_head_channels=attention_head_dim[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=BLOCK_OUT_CHANNELS[-1], + attn_num_head_channels=attention_head_dim[-1], + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(BLOCK_OUT_CHANNELS)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(UP_BLOCK_TYPES): + is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(BLOCK_OUT_CHANNELS) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + add_upsample=add_upsample, + attn_num_head_channels=reversed_attention_head_dim[i], + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=NORM_GROUPS, eps=NORM_EPS) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1) + + # region diffusers compatibility + def prepare_config(self, *args, **kwargs): + self.config = SimpleNamespace(**kwargs) + + @property + def dtype(self) -> torch.dtype: + # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + return get_parameter_dtype(self) + + @property + def device(self) -> torch.device: + # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device). + return get_parameter_device(self) + + def set_attention_slice(self, slice_size): + raise NotImplementedError("Attention slicing is not supported for this model.") + + def is_gradient_checkpointing(self) -> bool: + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) + + def enable_gradient_checkpointing(self): + self.set_gradient_checkpointing(value=True) + + def disable_gradient_checkpointing(self): + self.set_gradient_checkpointing(value=False) + + def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None: + modules = self.down_blocks + [self.mid_block] + self.up_blocks + for module in modules: + module.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa: bool) -> None: + modules = self.down_blocks + [self.mid_block] + self.up_blocks + for module in modules: + module.set_use_sdpa(sdpa) + + def set_gradient_checkpointing(self, value=False): + modules = self.down_blocks + [self.mid_block] + self.up_blocks + for module in modules: + logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}") + module.gradient_checkpointing = value + + # endregion + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + return_dict: bool = True, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + ) -> Union[Dict, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a dict instead of a plain tuple. + + Returns: + `SampleOutput` or `tuple`: + `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある + # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する + # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + # 64で割り切れないときはupsamplerにサイズを伝える + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + # logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 1. time + timesteps = timestep + timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理 + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + # timestepsは重みを含まないので常にfloat32のテンソルを返す + # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある + # time_projでキャストしておけばいいんじゃね? + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = self.conv_in(sample) + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 + # まあこちらのほうがわかりやすいかもしれない + if downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # skip connectionにControlNetの出力を追加する + if down_block_additional_residuals is not None: + down_block_res_samples = list(down_block_res_samples) + for i in range(len(down_block_res_samples)): + down_block_res_samples[i] += down_block_additional_residuals[i] + down_block_res_samples = tuple(down_block_res_samples) + + # 4. mid + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + + # ControlNetの出力を追加する + if mid_block_additional_residual is not None: + sample += mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection + + # if we have not reached the final block and need to forward the upsample size, we do it here + # 前述のように最後のブロック以外ではupsample_sizeを伝える + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return SampleOutput(sample=sample) + + def handle_unusual_timesteps(self, sample, timesteps): + r""" + timestampsがTensorでない場合、Tensorに変換する。またOnnx/Core MLと互換性のあるようにbatchサイズまでbroadcastする。 + """ + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timesteps, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + return timesteps + + +class InferUNet2DConditionModel: + def __init__(self, original_unet: UNet2DConditionModel): + self.delegate = original_unet + + # override original model's forward method: because forward is not called by `__call__` + # overriding `__call__` is not enough, because nn.Module.forward has a special handling + self.delegate.forward = self.forward + + # override original model's up blocks' forward method + for up_block in self.delegate.up_blocks: + if up_block.__class__.__name__ == "UpBlock2D": + + def resnet_wrapper(func, block): + def forward(*args, **kwargs): + return func(block, *args, **kwargs) + + return forward + + up_block.forward = resnet_wrapper(self.up_block_forward, up_block) + + elif up_block.__class__.__name__ == "CrossAttnUpBlock2D": + + def cross_attn_up_wrapper(func, block): + def forward(*args, **kwargs): + return func(block, *args, **kwargs) + + return forward + + up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block) + + # Deep Shrink + self.ds_depth_1 = None + self.ds_depth_2 = None + self.ds_timesteps_1 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + + # call original model's methods + def __getattr__(self, name): + return getattr(self.delegate, name) + + def __call__(self, *args, **kwargs): + return self.delegate(*args, **kwargs) + + def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): + if ds_depth_1 is None: + logger.info("Deep Shrink is disabled.") + self.ds_depth_1 = None + self.ds_timesteps_1 = None + self.ds_depth_2 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + else: + logger.info( + f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" + ) + self.ds_depth_1 = ds_depth_1 + self.ds_timesteps_1 = ds_timesteps_1 + self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1 + self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 + self.ds_ratio = ds_ratio + + def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in _self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # Deep Shrink + if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]: + hidden_states = resize_like(hidden_states, res_hidden_states) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + + if _self.upsamplers is not None: + for upsampler in _self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + def cross_attn_up_block_forward( + self, + _self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + ): + for resnet, attn in zip(_self.resnets, _self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # Deep Shrink + if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]: + hidden_states = resize_like(hidden_states, res_hidden_states) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if _self.upsamplers is not None: + for upsampler in _self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + return_dict: bool = True, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + ) -> Union[Dict, Tuple]: + r""" + current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink. + """ + + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a dict instead of a plain tuple. + + Returns: + `SampleOutput` or `tuple`: + `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + + _self = self.delegate + + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある + # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する + # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い + default_overall_up_factor = 2**_self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + # 64で割り切れないときはupsamplerにサイズを伝える + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + # logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 1. time + timesteps = timestep + timesteps = _self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理 + + t_emb = _self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + # timestepsは重みを含まないので常にfloat32のテンソルを返す + # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある + # time_projでキャストしておけばいいんじゃね? + t_emb = t_emb.to(dtype=_self.dtype) + emb = _self.time_embedding(t_emb) + + # 2. pre-process + sample = _self.conv_in(sample) + + down_block_res_samples = (sample,) + for depth, downsample_block in enumerate(_self.down_blocks): + # Deep Shrink + if self.ds_depth_1 is not None: + if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or ( + self.ds_depth_2 is not None + and depth == self.ds_depth_2 + and timesteps[0] < self.ds_timesteps_1 + and timesteps[0] >= self.ds_timesteps_2 + ): + org_dtype = sample.dtype + if org_dtype == torch.bfloat16: + sample = sample.to(torch.float32) + sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype) + + # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 + # まあこちらのほうがわかりやすいかもしれない + if downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # skip connectionにControlNetの出力を追加する + if down_block_additional_residuals is not None: + down_block_res_samples = list(down_block_res_samples) + for i in range(len(down_block_res_samples)): + down_block_res_samples[i] += down_block_additional_residuals[i] + down_block_res_samples = tuple(down_block_res_samples) + + # 4. mid + sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + + # ControlNetの出力を追加する + if mid_block_additional_residual is not None: + sample += mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(_self.up_blocks): + is_final_block = i == len(_self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection + + # if we have not reached the final block and need to forward the upsample size, we do it here + # 前述のように最後のブロック以外ではupsample_sizeを伝える + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + sample = _self.conv_norm_out(sample) + sample = _self.conv_act(sample) + sample = _self.conv_out(sample) + + if not return_dict: + return (sample,) + + return SampleOutput(sample=sample) diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..f584fe8dad4d45958cece116f4aa1ff6a6ffc531 --- /dev/null +++ b/library/sai_model_spec.py @@ -0,0 +1,334 @@ +# based on https://github.com/Stability-AI/ModelSpec +import datetime +import hashlib +from io import BytesIO +import os +from typing import List, Optional, Tuple, Union +import safetensors +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +r""" +# Metadata Example +metadata = { + # === Must === + "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec + "modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID + "modelspec.implementation": "sgm", + "modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc + # === Should === + "modelspec.author": "Example Corp", # Your name or company name + "modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know + "modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created + # === Can === + "modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc. + "modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model +} +""" + +BASE_METADATA = { + # === Must === + "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec + "modelspec.architecture": None, + "modelspec.implementation": None, + "modelspec.title": None, + "modelspec.resolution": None, + # === Should === + "modelspec.description": None, + "modelspec.author": None, + "modelspec.date": None, + # === Can === + "modelspec.license": None, + "modelspec.tags": None, + "modelspec.merged_from": None, + "modelspec.prediction_type": None, + "modelspec.timestep_range": None, + "modelspec.encoder_layer": None, +} + +# 別に使うやつだけ定義 +MODELSPEC_TITLE = "modelspec.title" + +ARCH_SD_V1 = "stable-diffusion-v1" +ARCH_SD_V2_512 = "stable-diffusion-v2-512" +ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v" +ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" +ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc. +# ARCH_SD3_UNKNOWN = "stable-diffusion-3" +ARCH_FLUX_1_DEV = "flux-1-dev" +ARCH_FLUX_1_UNKNOWN = "flux-1" + +ADAPTER_LORA = "lora" +ADAPTER_TEXTUAL_INVERSION = "textual-inversion" + +IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" +IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI" +IMPL_DIFFUSERS = "diffusers" +IMPL_FLUX = "https://github.com/black-forest-labs/flux" + +PRED_TYPE_EPSILON = "epsilon" +PRED_TYPE_V = "v" + + +def load_bytes_in_safetensors(tensors): + bytes = safetensors.torch.save(tensors) + b = BytesIO(bytes) + + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + + offset = n + 8 + b.seek(offset) + + return b.read() + + +def precalculate_safetensors_hashes(state_dict): + # calculate each tensor one by one to reduce memory usage + hash_sha256 = hashlib.sha256() + for tensor in state_dict.values(): + single_tensor_sd = {"tensor": tensor} + bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd) + hash_sha256.update(bytes_for_tensor) + + return f"0x{hash_sha256.hexdigest()}" + + +def update_hash_sha256(metadata: dict, state_dict: dict): + raise NotImplementedError + + +def build_metadata( + state_dict: Optional[dict], + v2: bool, + v_parameterization: bool, + sdxl: bool, + lora: bool, + textual_inversion: bool, + timestamp: float, + title: Optional[str] = None, + reso: Optional[Union[int, Tuple[int, int]]] = None, + is_stable_diffusion_ckpt: Optional[bool] = None, + author: Optional[str] = None, + description: Optional[str] = None, + license: Optional[str] = None, + tags: Optional[str] = None, + merged_from: Optional[str] = None, + timesteps: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + sd3: Optional[str] = None, + flux: Optional[str] = None, +): + """ + sd3: only supports "m", flux: only supports "dev" + """ + # if state_dict is None, hash is not calculated + + metadata = {} + metadata.update(BASE_METADATA) + + # TODO メモリを消費せずかつ正しいハッシュ計算の方法がわかったら実装する + # if state_dict is not None: + # hash = precalculate_safetensors_hashes(state_dict) + # metadata["modelspec.hash_sha256"] = hash + + if sdxl: + arch = ARCH_SD_XL_V1_BASE + elif sd3 is not None: + arch = ARCH_SD3_M + "-" + sd3 + elif flux is not None: + if flux == "dev": + arch = ARCH_FLUX_1_DEV + else: + arch = ARCH_FLUX_1_UNKNOWN + elif v2: + if v_parameterization: + arch = ARCH_SD_V2_768_V + else: + arch = ARCH_SD_V2_512 + else: + arch = ARCH_SD_V1 + + if lora: + arch += f"/{ADAPTER_LORA}" + elif textual_inversion: + arch += f"/{ADAPTER_TEXTUAL_INVERSION}" + + metadata["modelspec.architecture"] = arch + + if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: + is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion + + if flux is not None: + # Flux + impl = IMPL_FLUX + elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: + # Stable Diffusion ckpt, TI, SDXL LoRA + impl = IMPL_STABILITY_AI + else: + # v1/v2 LoRA or Diffusers + impl = IMPL_DIFFUSERS + metadata["modelspec.implementation"] = impl + + if title is None: + if lora: + title = "LoRA" + elif textual_inversion: + title = "TextualInversion" + else: + title = "Checkpoint" + title += f"@{timestamp}" + metadata[MODELSPEC_TITLE] = title + + if author is not None: + metadata["modelspec.author"] = author + else: + del metadata["modelspec.author"] + + if description is not None: + metadata["modelspec.description"] = description + else: + del metadata["modelspec.description"] + + if merged_from is not None: + metadata["modelspec.merged_from"] = merged_from + else: + del metadata["modelspec.merged_from"] + + if license is not None: + metadata["modelspec.license"] = license + else: + del metadata["modelspec.license"] + + if tags is not None: + metadata["modelspec.tags"] = tags + else: + del metadata["modelspec.tags"] + + # remove microsecond from time + int_ts = int(timestamp) + + # time to iso-8601 compliant date + date = datetime.datetime.fromtimestamp(int_ts).isoformat() + metadata["modelspec.date"] = date + + if reso is not None: + # comma separated to tuple + if isinstance(reso, str): + reso = tuple(map(int, reso.split(","))) + if len(reso) == 1: + reso = (reso[0], reso[0]) + else: + # resolution is defined in dataset, so use default + if sdxl or sd3 is not None or flux is not None: + reso = 1024 + elif v2 and v_parameterization: + reso = 768 + else: + reso = 512 + if isinstance(reso, int): + reso = (reso, reso) + + metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}" + + if flux is not None: + del metadata["modelspec.prediction_type"] + elif v_parameterization: + metadata["modelspec.prediction_type"] = PRED_TYPE_V + else: + metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON + + if timesteps is not None: + if isinstance(timesteps, str) or isinstance(timesteps, int): + timesteps = (timesteps, timesteps) + if len(timesteps) == 1: + timesteps = (timesteps[0], timesteps[0]) + metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}" + else: + del metadata["modelspec.timestep_range"] + + if clip_skip is not None: + metadata["modelspec.encoder_layer"] = f"{clip_skip}" + else: + del metadata["modelspec.encoder_layer"] + + # # assert all values are filled + # assert all([v is not None for v in metadata.values()]), metadata + if not all([v is not None for v in metadata.values()]): + logger.error(f"Internal error: some metadata values are None: {metadata}") + + return metadata + + +# region utils + + +def get_title(metadata: dict) -> Optional[str]: + return metadata.get(MODELSPEC_TITLE, None) + + +def load_metadata_from_safetensors(model: str) -> dict: + if not model.endswith(".safetensors"): + return {} + + with safetensors.safe_open(model, framework="pt") as f: + metadata = f.metadata() + if metadata is None: + metadata = {} + return metadata + + +def build_merged_from(models: List[str]) -> str: + def get_title(model: str): + metadata = load_metadata_from_safetensors(model) + title = metadata.get(MODELSPEC_TITLE, None) + if title is None: + title = os.path.splitext(os.path.basename(model))[0] # use filename + return title + + titles = [get_title(model) for model in models] + return ", ".join(titles) + + +# endregion + + +r""" +if __name__ == "__main__": + import argparse + import torch + from safetensors.torch import load_file + from library import train_util + + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt", type=str, required=True) + args = parser.parse_args() + + print(f"Loading {args.ckpt}") + state_dict = load_file(args.ckpt) + + print(f"Calculating metadata") + metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0) + print(metadata) + del state_dict + + # by reference implementation + with open(args.ckpt, mode="rb") as file_data: + file_hash = hashlib.sha256() + head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix + header = json.loads(file_data.read(head_len[0])) # header itself, json string + content = ( + file_data.read() + ) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl. + file_hash.update(content) + # ===== Update the hash for modelspec ===== + by_ref = f"0x{file_hash.hexdigest()}" + print(by_ref) + print("is same?", by_ref == metadata["modelspec.hash_sha256"]) + +""" diff --git a/library/sd3_models.py b/library/sd3_models.py new file mode 100644 index 0000000000000000000000000000000000000000..527dc54e385654515d6e73c364236566ce3d0b43 --- /dev/null +++ b/library/sd3_models.py @@ -0,0 +1,1422 @@ +# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref +# the original code is licensed under the MIT License + +# and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution! + +from dataclasses import dataclass +from functools import partial +import math + +from typing import Dict, List, Optional, Union +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + +from . import custom_offloading_utils + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +memory_efficient_attention = None +try: + import xformers +except: + pass + +try: + from xformers.ops import memory_efficient_attention +except: + memory_efficient_attention = None + + +# region mmdit + + +@dataclass +class SD3Params: + patch_size: int + depth: int + num_patches: int + pos_embed_max_size: int + adm_in_channels: int + qk_norm: Optional[str] + x_block_self_attn_layers: list[int] + context_embedder_in_features: int + context_embedder_out_features: int + model_type: str + + +def get_2d_sincos_pos_embed( + embed_dim, + grid_size, + scaling_factor=None, + offset=None, +): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + if scaling_factor is not None: + grid = grid / scaling_factor + if offset is not None: + grid = grid - offset + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_scaled_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, sample_size=64, base_size=16): + """ + This function is contributed by KohakuBlueleaf. Thanks for the contribution! + + Creates scaled 2D sinusoidal positional embeddings that maintain consistent relative positions + when the resolution differs from the training resolution. + + Args: + embed_dim (int): Dimension of the positional embedding. + grid_size (int or tuple): Size of the position grid (H, W). If int, assumes square grid. + cls_token (bool): Whether to include class token. Defaults to False. + extra_tokens (int): Number of extra tokens (e.g., cls_token). Defaults to 0. + sample_size (int): Reference resolution (typically training resolution). Defaults to 64. + base_size (int): Base grid size used during training. Defaults to 16. + + Returns: + numpy.ndarray: Positional embeddings of shape (H*W, embed_dim) or + (H*W + extra_tokens, embed_dim) if cls_token is True. + """ + # Convert grid_size to tuple if it's an integer + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + # Create normalized grid coordinates (0 to 1) + grid_h = np.arange(grid_size[0], dtype=np.float32) / grid_size[0] + grid_w = np.arange(grid_size[1], dtype=np.float32) / grid_size[1] + + # Calculate scaling factors for height and width + # This ensures that the central region matches the original resolution's embeddings + scale_h = base_size * grid_size[0] / (sample_size) + scale_w = base_size * grid_size[1] / (sample_size) + + # Calculate shift values to center the original resolution's embedding region + # This ensures that the central sample_size x sample_size region has similar + # positional embeddings to the original resolution + shift_h = 1 * scale_h * (grid_size[0] - sample_size) / (2 * grid_size[0]) + shift_w = 1 * scale_w * (grid_size[1] - sample_size) / (2 * grid_size[1]) + + # Apply scaling and shifting to create the final grid coordinates + grid_h = grid_h * scale_h - shift_h + grid_w = grid_w * scale_w - shift_w + + # Create 2D grid using meshgrid (note: w goes first) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0) + + # # Calculate the starting indices for the central region + # # This is used for debugging/visualization of the central region + # st_h = (grid_size[0] - sample_size) // 2 + # st_w = (grid_size[1] - sample_size) // 2 + # print(grid[:, st_h : st_h + sample_size, st_w : st_w + sample_size]) + + # Reshape grid for positional embedding calculation + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + + # Generate the sinusoidal positional embeddings + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + + # Add zeros for extra tokens (e.g., [CLS] token) if required + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + + return pos_embed + + +# if __name__ == "__main__": +# # This is what you get when you load SD3.5 state dict +# pos_emb = torch.from_numpy(get_scaled_2d_sincos_pos_embed( +# 1536, [384, 384], sample_size=64, base_size=16 +# )).float().unsqueeze(0) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid_torch( + embed_dim, + pos, + device=None, + dtype=torch.float32, +): + omega = torch.arange(embed_dim // 2, device=device, dtype=dtype) + omega *= 2.0 / embed_dim + omega = 1.0 / 10000**omega + out = torch.outer(pos.reshape(-1), omega) + emb = torch.cat([out.sin(), out.cos()], dim=1) + return emb + + +def get_2d_sincos_pos_embed_torch( + embed_dim, + w, + h, + val_center=7.5, + val_magnitude=7.5, + device=None, + dtype=torch.float32, +): + small = min(h, w) + val_h = (h / small) * val_magnitude + val_w = (w / small) * val_magnitude + grid_h, grid_w = torch.meshgrid( + torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype), + torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype), + indexing="ij", + ) + emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype) + emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype) + emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D) + return emb + + +def modulate(x, shift, scale): + if shift is None: + shift = torch.zeros_like(scale) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def default(x, default_value): + if x is None: + return default_value + return x + + +def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + # device=t.device, dtype=t.dtype + # ) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(dtype=t.dtype) + return embedding + + +class PatchEmbed(nn.Module): + def __init__( + self, + img_size=256, + patch_size=4, + in_channels=3, + embed_dim=512, + norm_layer=None, + flatten=True, + bias=True, + strict_img_size=True, + dynamic_img_pad=False, + ): + # dynamic_img_pad and norm is omitted in SD3.5 + super().__init__() + self.patch_size = patch_size + self.flatten = flatten + self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad + if img_size is not None: + self.img_size = img_size + self.grid_size = img_size // patch_size + self.num_patches = self.grid_size**2 + else: + self.img_size = None + self.grid_size = None + self.num_patches = None + + self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias) + self.norm = nn.Identity() if norm_layer is None else norm_layer(embed_dim) + + def forward(self, x): + B, C, H, W = x.shape + + if self.dynamic_img_pad: + # Pad input so we won't have partial patch + pad_h = (self.patch_size - H % self.patch_size) % self.patch_size + pad_w = (self.patch_size - W % self.patch_size) % self.patch_size + x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="reflect") + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +# FinalLayer in mmdit.py +class UnPatch(nn.Module): + def __init__(self, hidden_size=512, patch_size=4, out_channels=3): + super().__init__() + self.patch_size = patch_size + self.c = out_channels + + # eps is default in mmdit.py + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size), + ) + + def forward(self, x: torch.Tensor, cmod, H=None, W=None): + b, n, _ = x.shape + p = self.patch_size + c = self.c + if H is None and W is None: + w = h = int(n**0.5) + assert h * w == n + else: + h = H // p if H else n // (W // p) + w = W // p if W else n // h + assert h * w == n + + shift, scale = self.adaLN_modulation(cmod).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + + x = x.view(b, h, w, p, p, c) + x = x.permute(0, 5, 1, 3, 2, 4).contiguous() + x = x.view(b, c, h * p, w * p) + return x + + +class MLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=lambda: nn.GELU(), + norm_layer=None, + bias=True, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.use_conv = use_conv + + layer = partial(nn.Conv1d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = layer(in_features, hidden_features, bias=bias) + self.fc2 = layer(hidden_features, out_features, bias=bias) + self.act = act_layer() + self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.norm(x) + x = self.fc2(x) + return x + + +class TimestepEmbedding(nn.Module): + def __init__(self, hidden_size, freq_embed_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(freq_embed_size, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + self.freq_embed_size = freq_embed_size + + def forward(self, t, dtype=None, **kwargs): + t_freq = timestep_embedding(t, self.freq_embed_size).to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class Embedder(nn.Module): + def __init__(self, input_dim, hidden_size): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(input_dim, hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + + def forward(self, x): + return self.mlp(x) + + +def rmsnorm(x, eps=1e-6): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +class RMSNorm(torch.nn.Module): + def __init__( + self, + dim: int, + elementwise_affine: bool = False, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + """ + super().__init__() + self.eps = eps + self.learnable_scale = elementwise_affine + if self.learnable_scale: + self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + else: + self.register_parameter("weight", None) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + """ + x = rmsnorm(x, eps=self.eps) + if self.learnable_scale: + return x * self.weight.to(device=x.device, dtype=x.dtype) + else: + return x + + +class SwiGLUFeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: float = None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +# Linears for SelfAttention in mmdit.py +class AttentionLinears(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + pre_only: bool = False, + qk_norm: Optional[str] = None, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + if not pre_only: + self.proj = nn.Linear(dim, dim) + self.pre_only = pre_only + + if qk_norm == "rms": + self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + elif qk_norm == "ln": + self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6) + elif qk_norm is None: + self.ln_q = nn.Identity() + self.ln_k = nn.Identity() + else: + raise ValueError(qk_norm) + + def pre_attention(self, x: torch.Tensor) -> torch.Tensor: + """ + output: + q, k, v: [B, L, D] + """ + B, L, C = x.shape + qkv: torch.Tensor = self.qkv(x) + q, k, v = qkv.reshape(B, L, -1, self.head_dim).chunk(3, dim=2) + q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1) + k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1) + return (q, k, v) + + def post_attention(self, x: torch.Tensor) -> torch.Tensor: + assert not self.pre_only + x = self.proj(x) + return x + + +MEMORY_LAYOUTS = { + "torch": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2), + lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1), + lambda x: (1, x, 1, 1), + ), + "xformers": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim), + lambda x: x.reshape(x.shape[0], x.shape[1], -1), + lambda x: (1, 1, x, 1), + ), + "math": ( + lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2), + lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1), + lambda x: (1, x, 1, 1), + ), +} +# ATTN_FUNCTION = { +# "torch": F.scaled_dot_product_attention, +# "xformers": memory_efficient_attention, +# } + + +def vanilla_attention(q, k, v, mask, scale=None): + if scale is None: + scale = math.sqrt(q.size(-1)) + scores = torch.bmm(q, k.transpose(-1, -2)) / scale + if mask is not None: + mask = einops.rearrange(mask, "b ... -> b (...)") + max_neg_value = -torch.finfo(scores.dtype).max + mask = einops.repeat(mask, "b j -> (b h) j", h=q.size(-3)) + scores = scores.masked_fill(~mask, max_neg_value) + p_attn = F.softmax(scores, dim=-1) + return torch.bmm(p_attn, v) + + +def attention(q, k, v, head_dim, mask=None, scale=None, mode="xformers"): + """ + q, k, v: [B, L, D] + """ + pre_attn_layout = MEMORY_LAYOUTS[mode][0] + post_attn_layout = MEMORY_LAYOUTS[mode][1] + q = pre_attn_layout(q, head_dim) + k = pre_attn_layout(k, head_dim) + v = pre_attn_layout(v, head_dim) + + # scores = ATTN_FUNCTION[mode](q, k.to(q), v.to(q), mask, scale=scale) + if mode == "torch": + assert scale is None + scores = F.scaled_dot_product_attention(q, k.to(q), v.to(q), mask) # , scale=scale) + elif mode == "xformers": + scores = memory_efficient_attention(q, k.to(q), v.to(q), mask, scale=scale) + else: + scores = vanilla_attention(q, k.to(q), v.to(q), mask, scale=scale) + + scores = post_attn_layout(scores) + return scores + + +# DismantledBlock in mmdit.py +class SingleDiTBlock(nn.Module): + """ + A DiT block with gated adaptive layer norm (adaLN) conditioning. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: str = "xformers", + qkv_bias: bool = False, + pre_only: bool = False, + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + qk_norm: Optional[str] = None, + x_block_self_attn: bool = False, + **block_kwargs, + ): + super().__init__() + assert attn_mode in MEMORY_LAYOUTS + self.attn_mode = attn_mode + if not rmsnorm: + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + else: + self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=pre_only, qk_norm=qk_norm) + + self.x_block_self_attn = x_block_self_attn + if self.x_block_self_attn: + assert not pre_only + assert not scale_mod_only + self.attn2 = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=False, qk_norm=qk_norm) + + if not pre_only: + if not rmsnorm: + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + else: + self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + if not pre_only: + if not swiglu: + self.mlp = MLP( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=lambda: nn.GELU(approximate="tanh"), + ) + else: + self.mlp = SwiGLUFeedForward( + dim=hidden_size, + hidden_dim=mlp_hidden_dim, + multiple_of=256, + ) + self.scale_mod_only = scale_mod_only + if self.x_block_self_attn: + n_mods = 9 + elif not scale_mod_only: + n_mods = 6 if not pre_only else 2 + else: + n_mods = 4 if not pre_only else 1 + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size)) + self.pre_only = pre_only + + def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + if not self.pre_only: + if not self.scale_mod_only: + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(6, dim=-1) + else: + shift_msa = None + shift_mlp = None + (scale_msa, gate_msa, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(4, dim=-1) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp) + else: + if not self.scale_mod_only: + (shift_msa, scale_msa) = self.adaLN_modulation(c).chunk(2, dim=-1) + else: + shift_msa = None + scale_msa = self.adaLN_modulation(c) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, None + + def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + assert self.x_block_self_attn + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2) = self.adaLN_modulation( + c + ).chunk(9, dim=1) + x_norm = self.norm1(x) + qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa)) + qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2)) + return qkv, qkv2, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2) + + def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp): + assert not self.pre_only + x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2, attn1_dropout: float = 0.0): + assert not self.pre_only + if attn1_dropout > 0.0: + # Use torch.bernoulli to implement dropout, only dropout the batch dimension + attn1_dropout = torch.bernoulli(torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device)) + attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout + else: + attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + attn_ + attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2) + x = x + attn2_ + mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + x = x + mlp_ + return x + + +# JointBlock + block_mixing in mmdit.py +class MMDiTBlock(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + pre_only = kwargs.pop("pre_only") + x_block_self_attn = kwargs.pop("x_block_self_attn") + + self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs) + self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs) + + self.head_dim = self.x_block.attn.head_dim + self.mode = self.x_block.attn_mode + self.gradient_checkpointing = False + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def _forward(self, context, x, c): + ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c) + + if self.x_block.x_block_self_attn: + x_qkv, x_qkv2, x_intermediates = self.x_block.pre_attention_x(x, c) + else: + x_qkv, x_intermediates = self.x_block.pre_attention(x, c) + + ctx_len = ctx_qkv[0].size(1) + + q = torch.concat((ctx_qkv[0], x_qkv[0]), dim=1) + k = torch.concat((ctx_qkv[1], x_qkv[1]), dim=1) + v = torch.concat((ctx_qkv[2], x_qkv[2]), dim=1) + + attn = attention(q, k, v, head_dim=self.head_dim, mode=self.mode) + ctx_attn_out = attn[:, :ctx_len] + x_attn_out = attn[:, ctx_len:] + + if self.x_block.x_block_self_attn: + x_q2, x_k2, x_v2 = x_qkv2 + attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads, mode=self.mode) + x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates) + else: + x = self.x_block.post_attention(x_attn_out, *x_intermediates) + + if not self.context_block.pre_only: + context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate) + else: + context = None + + return context, x + + def forward(self, *args, **kwargs): + if self.training and self.gradient_checkpointing: + return checkpoint(self._forward, *args, use_reentrant=False, **kwargs) + else: + return self._forward(*args, **kwargs) + + +class MMDiT(nn.Module): + """ + Diffusion model with a Transformer backbone. + """ + + # prepare pos_embed for latent size * 2 + POS_EMBED_MAX_RATIO = 1.5 + + def __init__( + self, + input_size: int = 32, + patch_size: int = 2, + in_channels: int = 4, + depth: int = 28, + # hidden_size: Optional[int] = None, + # num_heads: Optional[int] = None, + mlp_ratio: float = 4.0, + learn_sigma: bool = False, + adm_in_channels: Optional[int] = None, + context_embedder_in_features: Optional[int] = None, + context_embedder_out_features: Optional[int] = None, + use_checkpoint: bool = False, + register_length: int = 0, + attn_mode: str = "torch", + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + out_channels: Optional[int] = None, + pos_embed_scaling_factor: Optional[float] = None, + pos_embed_offset: Optional[float] = None, + pos_embed_max_size: Optional[int] = None, + num_patches=None, + qk_norm: Optional[str] = None, + x_block_self_attn_layers: Optional[list[int]] = [], + qkv_bias: bool = True, + pos_emb_random_crop_rate: float = 0.0, + use_scaled_pos_embed: bool = False, + pos_embed_latent_sizes: Optional[list[int]] = None, + model_type: str = "sd3m", + ): + super().__init__() + self._model_type = model_type + self.learn_sigma = learn_sigma + self.in_channels = in_channels + default_out_channels = in_channels * 2 if learn_sigma else in_channels + self.out_channels = default(out_channels, default_out_channels) + self.patch_size = patch_size + self.pos_embed_scaling_factor = pos_embed_scaling_factor + self.pos_embed_offset = pos_embed_offset + self.pos_embed_max_size = pos_embed_max_size + self.x_block_self_attn_layers = x_block_self_attn_layers + self.pos_emb_random_crop_rate = pos_emb_random_crop_rate + self.gradient_checkpointing = use_checkpoint + + # hidden_size = default(hidden_size, 64 * depth) + # num_heads = default(num_heads, hidden_size // 64) + + # apply magic --> this defines a head_size of 64 + self.hidden_size = 64 * depth + num_heads = depth + + self.num_heads = num_heads + + self.enable_scaled_pos_embed(use_scaled_pos_embed, pos_embed_latent_sizes) + + self.x_embedder = PatchEmbed( + input_size, + patch_size, + in_channels, + self.hidden_size, + bias=True, + strict_img_size=self.pos_embed_max_size is None, + ) + self.t_embedder = TimestepEmbedding(self.hidden_size) + + self.y_embedder = None + if adm_in_channels is not None: + assert isinstance(adm_in_channels, int) + self.y_embedder = Embedder(adm_in_channels, self.hidden_size) + + if context_embedder_in_features is not None: + self.context_embedder = nn.Linear(context_embedder_in_features, context_embedder_out_features) + else: + self.context_embedder = nn.Identity() + + self.register_length = register_length + if self.register_length > 0: + self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size)) + + # num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + # just use a buffer already + if num_patches is not None: + self.register_buffer( + "pos_embed", + torch.empty(1, num_patches, self.hidden_size), + ) + else: + self.pos_embed = None + + self.use_checkpoint = use_checkpoint + self.joint_blocks = nn.ModuleList( + [ + MMDiTBlock( + self.hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + qkv_bias=qkv_bias, + pre_only=i == depth - 1, + rmsnorm=rmsnorm, + scale_mod_only=scale_mod_only, + swiglu=swiglu, + qk_norm=qk_norm, + x_block_self_attn=(i in self.x_block_self_attn_layers), + ) + for i in range(depth) + ] + ) + for block in self.joint_blocks: + block.gradient_checkpointing = use_checkpoint + + self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels) + # self.initialize_weights() + + self.blocks_to_swap = None + self.offloader = None + self.num_blocks = len(self.joint_blocks) + + def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]): + self.use_scaled_pos_embed = use_scaled_pos_embed + + if self.use_scaled_pos_embed: + # remove pos_embed to free up memory up to 0.4 GB + self.pos_embed = None + + # remove duplicates and sort latent sizes in ascending order + latent_sizes = list(set(latent_sizes)) + latent_sizes = sorted(latent_sizes) + + patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes] + + # calculate value range for each latent area: this is used to determine the pos_emb size from the latent shape + max_areas = [] + for i in range(1, len(patched_sizes)): + prev_area = patched_sizes[i - 1] ** 2 + area = patched_sizes[i] ** 2 + max_areas.append((prev_area + area) // 2) + + # area of the last latent size, if the latent size exceeds this, error will be raised + max_areas.append(int((patched_sizes[-1] * MMDiT.POS_EMBED_MAX_RATIO) ** 2)) + # print("max_areas", max_areas) + + self.resolution_area_to_latent_size = [(area, latent_size) for area, latent_size in zip(max_areas, patched_sizes)] + + self.resolution_pos_embeds = {} + for patched_size in patched_sizes: + grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO) + pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) + self.resolution_pos_embeds[patched_size] = pos_embed + # print(f"pos_embed for {patched_size}x{patched_size} latent size: {pos_embed.shape}") + + else: + self.resolution_area_to_latent_size = None + self.resolution_pos_embeds = None + + @property + def model_type(self): + return self._model_type + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + for block in self.joint_blocks: + block.enable_gradient_checkpointing() + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + for block in self.joint_blocks: + block.disable_gradient_checkpointing() + + def initialize_weights(self): + # TODO: Init context_embedder? + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize (and freeze) pos_embed by sin-cos embedding + if self.pos_embed is not None: + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], + int(self.pos_embed.shape[-2] ** 0.5), + scaling_factor=self.pos_embed_scaling_factor, + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + if getattr(self, "y_embedder", None) is not None: + nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.joint_blocks: + nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def set_pos_emb_random_crop_rate(self, rate: float): + self.pos_emb_random_crop_rate = rate + + def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False): + p = self.x_embedder.patch_size + # patched size + h = (h + 1) // p + w = (w + 1) // p + if self.pos_embed is None: # should not happen + return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device) + assert self.pos_embed_max_size is not None + assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size) + assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size) + + if not random_crop: + top = (self.pos_embed_max_size - h) // 2 + left = (self.pos_embed_max_size - w) // 2 + else: + top = torch.randint(0, self.pos_embed_max_size - h + 1, (1,)).item() + left = torch.randint(0, self.pos_embed_max_size - w + 1, (1,)).item() + + spatial_pos_embed = self.pos_embed.reshape( + 1, + self.pos_embed_max_size, + self.pos_embed_max_size, + self.pos_embed.shape[-1], + ) + spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + return spatial_pos_embed + + def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: bool = False): + p = self.x_embedder.patch_size + # patched size + h = (h + 1) // p + w = (w + 1) // p + + # select pos_embed size based on area + area = h * w + patched_size = None + for area_, patched_size_ in self.resolution_area_to_latent_size: + if area <= area_: + patched_size = patched_size_ + break + if patched_size is None: + # raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.") + # use largest latent size + patched_size = self.resolution_area_to_latent_size[-1][1] + + pos_embed = self.resolution_pos_embeds[patched_size] + pos_embed_size = round(math.sqrt(pos_embed.shape[1])) # max size, patched_size * POS_EMBED_MAX_RATIO + if h > pos_embed_size or w > pos_embed_size: + # # fallback to normal pos_embed + # return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop) + # extend pos_embed size + logger.warning( + f"Add new pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide." + ) + patched_size = max(h, w) + grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO) + pos_embed_size = grid_size + pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) + self.resolution_pos_embeds[patched_size] = pos_embed + logger.info(f"Added pos_embed for size {patched_size}x{patched_size}") + + # print(torch.allclose(pos_embed.to(torch.float32).cpu(), self.pos_embed.to(torch.float32).cpu(), atol=5e-2)) + # diff = pos_embed.to(torch.float32).cpu() - self.pos_embed.to(torch.float32).cpu() + # print(diff.abs().max(), diff.abs().mean()) + + # insert to resolution_area_to_latent_size, by adding and sorting + area = pos_embed_size**2 + self.resolution_area_to_latent_size.append((area, patched_size)) + self.resolution_area_to_latent_size = sorted(self.resolution_area_to_latent_size) + + if not random_crop: + top = (pos_embed_size - h) // 2 + left = (pos_embed_size - w) // 2 + else: + top = torch.randint(0, pos_embed_size - h + 1, (1,)).item() + left = torch.randint(0, pos_embed_size - w + 1, (1,)).item() + + if pos_embed.device != device: + pos_embed = pos_embed.to(device) + # which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device. + self.resolution_pos_embeds[patched_size] = pos_embed # update device + if pos_embed.dtype != dtype: + pos_embed = pos_embed.to(dtype) + self.resolution_pos_embeds[patched_size] = pos_embed # update dtype + + spatial_pos_embed = pos_embed.reshape(1, pos_embed_size, pos_embed_size, pos_embed.shape[-1]) + spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + # print( + # f"patched size: {h}x{w}, pos_embed size: {pos_embed_size}, pos_embed shape: {pos_embed.shape}, top: {top}, left: {left}" + # ) + return spatial_pos_embed + + def enable_block_swap(self, num_blocks: int, device: torch.device): + self.blocks_to_swap = num_blocks + + assert ( + self.blocks_to_swap <= self.num_blocks - 2 + ), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks." + + self.offloader = custom_offloading_utils.ModelOffloader( + self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True + ) + print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.") + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_blocks = self.joint_blocks + self.joint_blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.joint_blocks = save_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader.prepare_block_devices_before_forward(self.joint_blocks) + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass of DiT. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, D) tensor of class labels + """ + pos_emb_random_crop = ( + False if self.pos_emb_random_crop_rate == 0.0 else torch.rand(1).item() < self.pos_emb_random_crop_rate + ) + + B, C, H, W = x.shape + + # x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype) + if not self.use_scaled_pos_embed: + pos_embed = self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype) + else: + # print(f"Using scaled pos_embed for size {H}x{W}") + pos_embed = self.cropped_scaled_pos_embed(H, W, device=x.device, dtype=x.dtype, random_crop=pos_emb_random_crop) + x = self.x_embedder(x) + pos_embed + del pos_embed + + c = self.t_embedder(t, dtype=x.dtype) # (N, D) + if y is not None and self.y_embedder is not None: + y = self.y_embedder(y) # (N, D) + c = c + y # (N, D) + + if context is not None: + context = self.context_embedder(context) + + if self.register_length > 0: + context = torch.cat( + (einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), default(context, torch.Tensor([]).type_as(x))), 1 + ) + + if not self.blocks_to_swap: + for block in self.joint_blocks: + context, x = block(context, x, c) + else: + for block_idx, block in enumerate(self.joint_blocks): + self.offloader.wait_for_block(block_idx) + + context, x = block(context, x, c) + + self.offloader.submit_move_blocks(self.joint_blocks, block_idx) + + x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify + return x[:, :, :H, :W] + + +def create_sd3_mmdit(params: SD3Params, attn_mode: str = "torch") -> MMDiT: + mmdit = MMDiT( + input_size=None, + pos_embed_max_size=params.pos_embed_max_size, + patch_size=params.patch_size, + in_channels=16, + adm_in_channels=params.adm_in_channels, + context_embedder_in_features=params.context_embedder_in_features, + context_embedder_out_features=params.context_embedder_out_features, + depth=params.depth, + mlp_ratio=4, + qk_norm=params.qk_norm, + x_block_self_attn_layers=params.x_block_self_attn_layers, + num_patches=params.num_patches, + attn_mode=attn_mode, + model_type=params.model_type, + ) + return mmdit + + +# endregion + +# region VAE + +VAE_SCALE_FACTOR = 1.5305 +VAE_SHIFT_FACTOR = 0.0609 + + +def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) + + +class ResnetBlock(torch.nn.Module): + def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = Normalize(in_channels, dtype=dtype, device=device) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.norm2 = Normalize(out_channels, dtype=dtype, device=device) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + if self.in_channels != self.out_channels: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device + ) + else: + self.nin_shortcut = None + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + hidden = x + hidden = self.norm1(hidden) + hidden = self.swish(hidden) + hidden = self.conv1(hidden) + hidden = self.norm2(hidden) + hidden = self.swish(hidden) + hidden = self.conv2(hidden) + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + return x + hidden + + +class AttnBlock(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.norm = Normalize(in_channels, dtype=dtype, device=device) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device) + + def forward(self, x): + hidden = self.norm(x) + q = self.q(hidden) + k = self.k(hidden) + v = self.v(hidden) + b, c, h, w = q.shape + q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)) + hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default + hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + hidden = self.proj_out(hidden) + return x + hidden + + +class Downsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device) + + def forward(self, x): + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + + def forward(self, x): + org_dtype = x.dtype + if x.dtype == torch.bfloat16: + x = x.to(torch.float32) + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if x.dtype != org_dtype: + x = x.to(org_dtype) + x = self.conv(x) + return x + + +class VAEEncoder(torch.nn.Module): + def __init__( + self, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = torch.nn.ModuleList() + for i_level in range(self.num_resolutions): + block = torch.nn.ModuleList() + attn = torch.nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device)) + block_in = block_out + down = torch.nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, dtype=dtype, device=device) + self.down.append(down) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = self.swish(h) + h = self.conv_out(h) + return h + + +class VAEDecoder(torch.nn.Module): + def __init__( + self, + ch=128, + out_ch=3, + ch_mult=(1, 2, 4, 4), + num_res_blocks=2, + resolution=256, + z_channels=16, + dtype=torch.float32, + device=None, + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device) + # upsampling + self.up = torch.nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = torch.nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device)) + block_in = block_out + up = torch.nn.Module() + up.block = block + if i_level != 0: + up.upsample = Upsample(block_in, dtype=dtype, device=device) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, z): + # z to block_in + hidden = self.conv_in(z) + # middle + hidden = self.mid.block_1(hidden) + hidden = self.mid.attn_1(hidden) + hidden = self.mid.block_2(hidden) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + hidden = self.up[i_level].block[i_block](hidden) + if i_level != 0: + hidden = self.up[i_level].upsample(hidden) + # end + hidden = self.norm_out(hidden) + hidden = self.swish(hidden) + hidden = self.conv_out(hidden) + return hidden + + +class SDVAE(torch.nn.Module): + def __init__(self, dtype=torch.float32, device=None): + super().__init__() + self.encoder = VAEEncoder(dtype=dtype, device=device) + self.decoder = VAEDecoder(dtype=dtype, device=device) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + # @torch.autocast("cuda", dtype=torch.float16) + def decode(self, latent): + return self.decoder(latent) + + # @torch.autocast("cuda", dtype=torch.float16) + def encode(self, image): + hidden = self.encoder(image) + mean, logvar = torch.chunk(hidden, 2, dim=1) + logvar = torch.clamp(logvar, -30.0, 20.0) + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + + @staticmethod + def process_in(latent): + return (latent - VAE_SHIFT_FACTOR) * VAE_SCALE_FACTOR + + @staticmethod + def process_out(latent): + return (latent / VAE_SCALE_FACTOR) + VAE_SHIFT_FACTOR + + +# endregion diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..355f0856ef728e812baf844fee1a8899e8e58b71 --- /dev/null +++ b/library/sd3_train_utils.py @@ -0,0 +1,938 @@ +import argparse +import math +import os +import toml +import json +import time +from typing import Dict, List, Optional, Tuple, Union + +import torch +from safetensors.torch import save_file +from accelerate import Accelerator, PartialState +from tqdm import tqdm +from PIL import Image +from transformers import CLIPTextModelWithProjection, T5EncoderModel + +from . import sd3_models, sd3_utils, strategy_base, train_util +from .device_utils import init_ipex, clean_memory_on_device +from comfy.utils import ProgressBar +init_ipex() + +# from transformers import CLIPTokenizer +# from library import model_util +# , sdxl_model_util, train_util, sdxl_original_unet +# from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from . import sd3_models, sd3_utils, strategy_base, train_util + + +def save_models( + ckpt_path: str, + mmdit: Optional[sd3_models.MMDiT], + vae: Optional[sd3_models.SDVAE], + clip_l: Optional[CLIPTextModelWithProjection], + clip_g: Optional[CLIPTextModelWithProjection], + t5xxl: Optional[T5EncoderModel], + sai_metadata: Optional[dict], + save_dtype: Optional[torch.dtype] = None, +): + r""" + Save models to checkpoint file. Only supports unified checkpoint format. + """ + + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + update_sd("model.diffusion_model.", mmdit.state_dict()) + update_sd("first_stage_model.", vae.state_dict()) + + # do not support unified checkpoint format for now + # if clip_l is not None: + # update_sd("text_encoders.clip_l.", clip_l.state_dict()) + # if clip_g is not None: + # update_sd("text_encoders.clip_g.", clip_g.state_dict()) + # if t5xxl is not None: + # update_sd("text_encoders.t5xxl.", t5xxl.state_dict()) + + save_file(state_dict, ckpt_path, metadata=sai_metadata) + + if clip_l is not None: + clip_l_path = ckpt_path.replace(".safetensors", "_clip_l.safetensors") + save_file(clip_l.state_dict(), clip_l_path) + if clip_g is not None: + clip_g_path = ckpt_path.replace(".safetensors", "_clip_g.safetensors") + save_file(clip_g.state_dict(), clip_g_path) + if t5xxl is not None: + t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors") + t5xxl_state_dict = t5xxl.state_dict() + + # replace "shared.weight" with copy of it to avoid annoying shared tensor error on safetensors.save_file + shared_weight = t5xxl_state_dict["shared.weight"] + shared_weight_copy = shared_weight.detach().clone() + t5xxl_state_dict["shared.weight"] = shared_weight_copy + + save_file(t5xxl_state_dict, t5xxl_path) + + +def save_sd3_model_on_train_end( + args: argparse.Namespace, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + clip_l: Optional[CLIPTextModelWithProjection], + clip_g: Optional[CLIPTextModelWithProjection], + t5xxl: Optional[T5EncoderModel], + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec( + None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type + ) + save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype) + + train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_sd3_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + clip_l: Optional[CLIPTextModelWithProjection], + clip_g: Optional[CLIPTextModelWithProjection], + t5xxl: Optional[T5EncoderModel], + mmdit: sd3_models.MMDiT, + vae: sd3_models.SDVAE, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec( + None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type + ) + save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + True, + True, + epoch, + num_train_epochs, + global_step, + sd_saver, + None, + ) + + +def add_sd3_training_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--clip_l", + type=str, + required=False, + help="CLIP-L model path. if not specified, use ckpt's state_dict / CLIP-Lモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--clip_g", + type=str, + required=False, + help="CLIP-G model path. if not specified, use ckpt's state_dict / CLIP-Gモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--t5xxl", + type=str, + required=False, + help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用", + ) + parser.add_argument( + "--save_clip", + action="store_true", + help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません", + ) + parser.add_argument( + "--save_t5xxl", + action="store_true", + help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません", + ) + + parser.add_argument( + "--t5xxl_device", + type=str, + default=None, + help="[DOES NOT WORK] not supported yet. T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用", + ) + parser.add_argument( + "--t5xxl_dtype", + type=str, + default=None, + help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用", + ) + + parser.add_argument( + "--t5xxl_max_token_length", + type=int, + default=256, + help="maximum token length for T5-XXL. 256 is the default value / T5-XXLの最大トークン長。デフォルトは256", + ) + parser.add_argument( + "--apply_lg_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する", + ) + parser.add_argument( + "--apply_t5_attn_mask", + action="store_true", + help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する", + ) + parser.add_argument( + "--clip_l_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for CLIP-L encoder, default is 0.0 / CLIP-Lエンコーダのドロップアウト率、デフォルトは0.0", + ) + parser.add_argument( + "--clip_g_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for CLIP-G encoder, default is 0.0 / CLIP-Gエンコーダのドロップアウト率、デフォルトは0.0", + ) + parser.add_argument( + "--t5_dropout_rate", + type=float, + default=0.0, + help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0", + ) + parser.add_argument( + "--pos_emb_random_crop_rate", + type=float, + default=0.0, + help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M" + " / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります", + ) + parser.add_argument( + "--enable_scaled_pos_embed", + action="store_true", + help="Scale position embeddings for each resolution during multi-resolution training. Only for SD3.5M" + " / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります", + ) + + # Dependencies of Diffusers noise sampler has been removed for clarity in training + + parser.add_argument( + "--training_shift", + type=float, + default=1.0, + help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。", + ) + + +def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): + assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" + if args.v_parameterization: + logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") + + if args.clip_skip is not None: + logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") + + # if args.multires_noise_iterations: + # logger.info( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" + # ) + # else: + # if args.noise_offset is None: + # args.noise_offset = DEFAULT_NOISE_OFFSET + # elif args.noise_offset != DEFAULT_NOISE_OFFSET: + # logger.info( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" + # ) + # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") + + assert ( + not hasattr(args, "weighted_captions") or not args.weighted_captions + ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" + + if supportTextEncoderCaching: + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + args.cache_text_encoder_outputs = True + logger.warning( + "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / " + + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました" + ) + + +# temporary copied from sd3_minimal_inferece.py + + +def get_all_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps): + start = sampling.timestep(sampling.sigma_max) + end = sampling.timestep(sampling.sigma_min) + timesteps = torch.linspace(start, end, steps) + sigs = [] + for x in range(len(timesteps)): + ts = timesteps[x] + sigs.append(sampling.sigma(ts)) + sigs += [0.0] + return torch.FloatTensor(sigs) + + +def max_denoise(model_sampling, sigmas): + max_sigma = float(model_sampling.sigma_max) + sigma = float(sigmas[0]) + return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma + + +def do_sample( + height: int, + width: int, + seed: int, + cond: Tuple[torch.Tensor, torch.Tensor], + neg_cond: Tuple[torch.Tensor, torch.Tensor], + mmdit: sd3_models.MMDiT, + steps: int, + guidance_scale: float, + dtype: torch.dtype, + device: str, +): + latent = torch.zeros(1, 16, height // 8, width // 8, device=device) + latent = latent.to(dtype).to(device) + + # noise = get_noise(seed, latent).to(device) + if seed is not None: + generator = torch.manual_seed(seed) + else: + generator = None + noise = ( + torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu") + .to(latent.dtype) + .to(device) + ) + + model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3 + + sigmas = get_all_sigmas(model_sampling, steps).to(device) + + noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas)) + + c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype) + y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype) + + x = noise_scaled.to(device).to(dtype) + # print(x.shape) + + # with torch.no_grad(): + comfy_pbar = ProgressBar(len(sigmas) - 1) + for i in tqdm(range(len(sigmas) - 1)): + sigma_hat = sigmas[i] + + timestep = model_sampling.timestep(sigma_hat).float() + timestep = torch.FloatTensor([timestep, timestep]).to(device) + + x_c_nc = torch.cat([x, x], dim=0) + # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) + + mmdit.prepare_block_swap_before_forward() + model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) + model_output = model_output.float() + batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) + + pos_out, neg_out = batched.chunk(2) + denoised = neg_out + (pos_out - neg_out) * guidance_scale + # print(denoised.shape) + + # d = to_d(x, sigma_hat, denoised) + dims_to_append = x.ndim - sigma_hat.ndim + sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append] + # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape) + """Converts a denoiser output to a Karras ODE derivative.""" + d = (x - denoised) / sigma_hat_dims + + dt = sigmas[i + 1] - sigma_hat + + # Euler method + x = x + d * dt + x = x.to(dtype) + comfy_pbar.update(1) + + mmdit.prepare_block_swap_before_forward() + return x + + +def sample_images( + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + mmdit, + vae, + text_encoders, + sample_prompts_te_outputs, + prompt_replacement=None, + validation_settings=None, +): + logger.info("") + logger.info(f"generating sample images at step: {steps}") + + # unwrap unet and text_encoder(s) + mmdit = accelerator.unwrap_model(mmdit) + text_encoders = None if text_encoders is None else [accelerator.unwrap_model(te) for te in text_encoders] + # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) + + prompts = [] + for line in args.sample_prompts: + line = line.strip() + if len(line) > 0 and line[0] != "#": + prompts.append(line) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from .train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + with torch.no_grad(), accelerator.autocast(): + image_tensor_list = [] + for prompt_dict in prompts: + image_tensor = sample_image_inference( + accelerator, + args, + mmdit, + text_encoders, + vae, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, + validation_settings + ) + print(f"Sampled image shape: {image_tensor.shape}") + image_tensor_list.append(image_tensor) + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + clean_memory_on_device(accelerator.device) + return torch.cat(image_tensor_list, dim=0) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + mmdit: sd3_models.MMDiT, + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], + vae: sd3_models.SDVAE, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + validation_settings=None, + prompt_replacement=None, + +): + assert isinstance(prompt_dict, dict) + if validation_settings is not None: + sample_steps = validation_settings["steps"] + width = validation_settings["width"] + height = validation_settings["height"] + scale = validation_settings["guidance_scale"] + seed = validation_settings["seed"] + else: + sample_steps = prompt_dict.get("sample_steps", 30) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 7.5) + seed = prompt_dict.get("seed") + # controlnet_image = prompt_dict.get("controlnet_image") + negative_prompt = prompt_dict.get("negative_prompt") + prompt: str = prompt_dict.get("prompt", "") + # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + torch.cuda.seed() + + if negative_prompt is None: + negative_prompt = "" + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + logger.info(f"prompt: {prompt}") + logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + # logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + + # encode prompts + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + + def encode_prompt(prpt): + text_encoder_conds = [] + if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs: + text_encoder_conds = sample_prompts_te_outputs[prpt] + print(f"Using cached text encoder outputs for prompt: {prpt}") + if text_encoders is not None: + print(f"Encoding prompt: {prpt}") + tokens_and_masks = tokenize_strategy.tokenize(prpt) + # strategy has apply_t5_attn_mask option + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + return text_encoder_conds + + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(prompt) + cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # encode negative prompts + lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(negative_prompt) + neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled) + + # sample image + clean_memory_on_device(accelerator.device) + with accelerator.autocast(), torch.no_grad(): + # mmdit may be fp8, so we need weight_dtype here. vae is always in that dtype. + latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, vae.dtype, accelerator.device) + + # latent to image + clean_memory_on_device(accelerator.device) + org_vae_device = vae.device # will be on cpu + vae.to(accelerator.device) + latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype)) + image_tensor = vae.decode(latents) + vae.to(org_vae_device) + clean_memory_on_device(accelerator.device) + + image = image_tensor.float() + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2) + decoded_np = decoded_np.astype(np.uint8) + + image = Image.fromarray(decoded_np) + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image.save(os.path.join(save_dir, img_filename)) + return image_tensor + + # # send images to wandb if enabled + # if "wandb" in [tracker.name for tracker in accelerator.trackers]: + # wandb_tracker = accelerator.get_tracker("wandb") + + # import wandb + + # # not to commit images to avoid inconsistency between training and logging steps + # wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption + + +# region Diffusers + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils.torch_utils import randn_tensor +from diffusers.utils import BaseOutput + + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + ): + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + + self._step_index = None + self._begin_index = None + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps) + + sigmas = timesteps / self.config.num_train_timesteps + sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + + timesteps = sigmas * self.config.num_train_timesteps + self.timesteps = timesteps.to(device=device) + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self._step_index = None + self._begin_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + sigma = self.sigmas[self.step_index] + + gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + + noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + # NOTE: "original_sample" should not be an expected prediction_type but is left in for + # backwards compatibility + + # if self.config.prediction_type == "vector_field": + + denoised = sample - model_output * sigma + # 2. Convert to an ODE derivative + derivative = (sample - denoised) / sigma_hat + + dt = self.sigmas[self.step_index + 1] - sigma_hat + + prev_sample = sample + derivative * dt + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + + def __len__(self): + return self.config.num_train_timesteps + + +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device="cpu") + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) + return weighting + + +# endregion + + +def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz = latents.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + t_min = args.min_timestep if args.min_timestep is not None else 0 + t_max = args.max_timestep if args.max_timestep is not None else 1000 + shift = args.training_shift + + # weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details) + u = (u * shift) / (1 + (shift - 1) * u) + + indices = (u * (t_max - t_min) + t_min).long() + timesteps = indices.to(device=device, dtype=dtype) + + # sigmas according to flowmatching + sigmas = timesteps / 1000 + sigmas = sigmas.view(-1, 1, 1, 1) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + + return noisy_model_input, timesteps, sigmas diff --git a/library/sd3_utils.py b/library/sd3_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..76dd3c18bb020cfa01dfac04d7efc165c2c00849 --- /dev/null +++ b/library/sd3_utils.py @@ -0,0 +1,296 @@ +import math +import re +from typing import Dict, List, Optional, Union +import torch +from accelerate import init_empty_weights +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPConfig, CLIPTextConfig + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from . import sd3_models + +# region models + +# TODO remove dependency on flux_utils +from .utils import load_safetensors +from .flux_utils import load_t5xxl as flux_utils_load_t5xxl + + +def analyze_state_dict_state(state_dict: Dict, prefix: str = ""): + logger.info(f"Analyzing state dict state...") + + # analyze configs + patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2] + depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64 + num_patches = state_dict[f"{prefix}pos_embed"].shape[1] + pos_embed_max_size = round(math.sqrt(num_patches)) + adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1] + context_shape = state_dict[f"{prefix}context_embedder.weight"].shape + qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys() else None + + # x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1])) + x_block_self_attn_layers = [] + re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight") + for key in list(state_dict.keys()): + m = re_attn.search(key) + if m: + x_block_self_attn_layers.append(int(m.group(1))) + + context_embedder_in_features = context_shape[1] + context_embedder_out_features = context_shape[0] + + # only supports 3-5-large, medium or 3-medium + if qk_norm is not None: + if len(x_block_self_attn_layers) == 0: + model_type = "3-5-large" + else: + model_type = "3-5-medium" + else: + model_type = "3-medium" + + params = sd3_models.SD3Params( + patch_size=patch_size, + depth=depth, + num_patches=num_patches, + pos_embed_max_size=pos_embed_max_size, + adm_in_channels=adm_in_channels, + qk_norm=qk_norm, + x_block_self_attn_layers=x_block_self_attn_layers, + context_embedder_in_features=context_embedder_in_features, + context_embedder_out_features=context_embedder_out_features, + model_type=model_type, + ) + logger.info(f"Analyzed state dict state: {params}") + return params + + +def load_mmdit( + state_dict: Dict, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], attn_mode: str = "torch" +) -> sd3_models.MMDiT: + mmdit_sd = {} + + mmdit_prefix = "model.diffusion_model." + for k in list(state_dict.keys()): + if k.startswith(mmdit_prefix): + mmdit_sd[k[len(mmdit_prefix) :]] = state_dict.pop(k) + + # load MMDiT + logger.info("Building MMDit") + params = analyze_state_dict_state(mmdit_sd) + with init_empty_weights(): + mmdit = sd3_models.create_sd3_mmdit(params, attn_mode) + + logger.info("Loading state dict...") + info = mmdit.load_state_dict(mmdit_sd, strict=False, assign=True) + logger.info(f"Loaded MMDiT: {info}") + return mmdit + + +def load_clip_l( + clip_l_path: Optional[str], + dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[Dict] = None, +): + clip_l_sd = None + if clip_l_path is None: + if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_l: remove prefix "text_encoders.clip_l." + logger.info("clip_l is included in the checkpoint") + clip_l_sd = {} + prefix = "text_encoders.clip_l." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_l_sd[k[len(prefix) :]] = state_dict.pop(k) + elif clip_l_path is None: + logger.info("clip_l is not included in the checkpoint and clip_l_path is not provided") + return None + + # load clip_l + logger.info("Building CLIP-L") + config = CLIPTextConfig( + vocab_size=49408, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=77, + hidden_act="quick_gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=768, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + with init_empty_weights(): + clip = CLIPTextModelWithProjection(config) + + if clip_l_sd is None: + logger.info(f"Loading state dict from {clip_l_path}") + clip_l_sd = load_safetensors(clip_l_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + + if "text_projection.weight" not in clip_l_sd: + logger.info("Adding text_projection.weight to clip_l_sd") + clip_l_sd["text_projection.weight"] = torch.eye(768, dtype=dtype, device=device) + + info = clip.load_state_dict(clip_l_sd, strict=False, assign=True) + logger.info(f"Loaded CLIP-L: {info}") + return clip + + +def load_clip_g( + clip_g_path: Optional[str], + dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[Dict] = None, +): + clip_g_sd = None + if state_dict is not None: + if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict: + # found clip_g: remove prefix "text_encoders.clip_g." + logger.info("clip_g is included in the checkpoint") + clip_g_sd = {} + prefix = "text_encoders.clip_g." + for k in list(state_dict.keys()): + if k.startswith(prefix): + clip_g_sd[k[len(prefix) :]] = state_dict.pop(k) + elif clip_g_path is None: + logger.info("clip_g is not included in the checkpoint and clip_g_path is not provided") + return None + + # load clip_g + logger.info("Building CLIP-G") + config = CLIPTextConfig( + vocab_size=49408, + hidden_size=1280, + intermediate_size=5120, + num_hidden_layers=32, + num_attention_heads=20, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=1280, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + with init_empty_weights(): + clip = CLIPTextModelWithProjection(config) + + if clip_g_sd is None: + logger.info(f"Loading state dict from {clip_g_path}") + clip_g_sd = load_safetensors(clip_g_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = clip.load_state_dict(clip_g_sd, strict=False, assign=True) + logger.info(f"Loaded CLIP-G: {info}") + return clip + + +def load_t5xxl( + t5xxl_path: Optional[str], + dtype: Optional[Union[str, torch.dtype]], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[Dict] = None, +): + t5xxl_sd = None + if state_dict is not None: + if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict: + # found t5xxl: remove prefix "text_encoders.t5xxl." + logger.info("t5xxl is included in the checkpoint") + t5xxl_sd = {} + prefix = "text_encoders.t5xxl." + for k in list(state_dict.keys()): + if k.startswith(prefix): + t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k) + elif t5xxl_path is None: + logger.info("t5xxl is not included in the checkpoint and t5xxl_path is not provided") + return None + + return flux_utils_load_t5xxl(t5xxl_path, dtype, device, disable_mmap, state_dict=t5xxl_sd) + + +def load_vae( + vae_path: Optional[str], + vae_dtype: Optional[Union[str, torch.dtype]], + device: Optional[Union[str, torch.device]], + disable_mmap: bool = False, + state_dict: Optional[Dict] = None, +): + vae_sd = {} + if vae_path: + logger.info(f"Loading VAE from {vae_path}...") + vae_sd = load_safetensors(vae_path, device, disable_mmap) + else: + # remove prefix "first_stage_model." + vae_sd = {} + vae_prefix = "first_stage_model." + for k in list(state_dict.keys()): + if k.startswith(vae_prefix): + vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k) + + logger.info("Building VAE") + vae = sd3_models.SDVAE(vae_dtype, device) + logger.info("Loading state dict...") + info = vae.load_state_dict(vae_sd) + logger.info(f"Loaded VAE: {info}") + vae.to(device=device, dtype=vae_dtype) # make sure it's in the right device and dtype + return vae + + +# endregion + + +class ModelSamplingDiscreteFlow: + """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models""" + + def __init__(self, shift=1.0): + self.shift = shift + timesteps = 1000 + self.sigmas = self.sigma(torch.arange(1, timesteps + 1, 1)) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return sigma * 1000 + + def sigma(self, timestep: torch.Tensor): + timestep = timestep / 1000.0 + if self.shift == 1.0: + return timestep + return self.shift * timestep / (1 + (self.shift - 1) * timestep) + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input - model_output * sigma + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + # assert max_denoise is False, "max_denoise not implemented" + # max_denoise is always True, I'm not sure why it's there + return sigma * noise + (1.0 - sigma) * latent_image diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..078b63fb9b05ede7ca1cc9ed89ecccfdb09a4391 --- /dev/null +++ b/library/sdxl_lpw_stable_diffusion.py @@ -0,0 +1,1275 @@ +# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py +# and modify to support SD2.x + +import inspect +import re +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from packaging import version +from tqdm import tqdm +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from diffusers import SchedulerMixin, StableDiffusionPipeline +from diffusers.models import AutoencoderKL +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker +from diffusers.utils import logging +from PIL import Image + +from . import ( + sdxl_model_util, + sdxl_train_util, + strategy_base, + strategy_sdxl, + train_util, + sdxl_original_unet, + sdxl_original_control_net, +) + + +try: + from diffusers.utils import PIL_INTERPOLATION +except ImportError: + if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } + else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } + +from comfy.utils import ProgressBar +# ------------------------------------------------------------------------------ + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = pipe.tokenizer(word).input_ids[1:-1] + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, eos_token_id, device): + if not is_sdxl_text_encoder2: + # text_encoder1: same as SD1/2 + enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True) + hidden_states = enc_out["hidden_states"][11] + pool = None + else: + # text_encoder2 + enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True) + hidden_states = enc_out["hidden_states"][-2] # penuultimate layer + # pool = enc_out["text_embeds"] + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], input_ids, eos_token_id) + hidden_states = hidden_states.to(device) + if pool is not None: + pool = pool.to(device) + return hidden_states, pool + + +def get_unweighted_text_embeddings( + pipe: StableDiffusionPipeline, + text_input: torch.Tensor, + chunk_length: int, + clip_skip: int, + eos: int, + pad: int, + is_sdxl_text_encoder2: bool, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. + """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + text_pool = None + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos + + text_embedding, current_text_pool = get_hidden_states( + pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, eos, pipe.device + ) + if text_pool is None: + text_pool = current_text_pool + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, eos, pipe.device) + return text_embeddings, text_pool + + +def get_weighted_text_embeddings( + pipe, # : SdxlStableDiffusionLongPromptWeightingPipeline, + prompt: Union[str, List[str]], + uncond_prompt: Optional[Union[str, List[str]]] = None, + max_embeddings_multiples: Optional[int] = 3, + no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, + clip_skip=None, + is_sdxl_text_encoder2=False, +): + r""" + Prompts can be assigned with local weights using brackets. For example, + prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', + and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. + + Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. + + Args: + pipe (`StableDiffusionPipeline`): + Pipe to provide access to the tokenizer and the text encoder. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + uncond_prompt (`str` or `List[str]`): + The unconditional prompt or prompts for guide the image generation. If unconditional prompt + is provided, the embeddings of prompt and uncond_prompt are concatenated. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + no_boseos_middle (`bool`, *optional*, defaults to `False`): + If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and + ending token in each of the chunk in the middle. + skip_parsing (`bool`, *optional*, defaults to `False`): + Skip the parsing of brackets. + skip_weighting (`bool`, *optional*, defaults to `False`): + Skip the weighting. When the parsing is skipped, it is forced True. + """ + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) + else: + prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [ + token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids + ] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = pipe.tokenizer.bos_token_id + eos = pipe.tokenizer.eos_token_id + pad = pipe.tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) + + # get the embeddings + text_embeddings, text_pool = get_unweighted_text_embeddings( + pipe, + prompt_tokens, + pipe.tokenizer.model_max_length, + clip_skip, + eos, + pad, + is_sdxl_text_encoder2, + no_boseos_middle=no_boseos_middle, + ) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) + + if uncond_prompt is not None: + uncond_embeddings, uncond_pool = get_unweighted_text_embeddings( + pipe, + uncond_tokens, + pipe.tokenizer.model_max_length, + clip_skip, + eos, + pad, + is_sdxl_text_encoder2, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? + if (not skip_parsing) and (not skip_weighting): + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + if uncond_prompt is not None: + return text_embeddings, text_pool, uncond_embeddings, uncond_pool + return text_embeddings, text_pool, None, None + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, scale_factor=8): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +def prepare_controlnet_image( + image: PIL.Image.Image, + width: int, + height: int, + batch_size: int, + num_images_per_prompt: int, + device: torch.device, + dtype: torch.dtype, + do_classifier_free_guidance: bool = False, + guess_mode: bool = False, +): + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + images = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + + image = images + + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + +class SdxlStableDiffusionLongPromptWeightingPipeline: + r""" + Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing + weighting in prompt. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"): + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: List[CLIPTextModel], + tokenizer: List[CLIPTokenizer], + unet: Union[sdxl_original_unet.SdxlUNet2DConditionModel, sdxl_original_control_net.SdxlControlledUNet], + scheduler: SchedulerMixin, + # clip_skip: int, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + clip_skip: int = 1, + ): + # clip skip is ignored currently + self.tokenizer = tokenizer[0] + self.text_encoder = text_encoder[0] + self.unet = unet + self.scheduler = scheduler + self.safety_checker = safety_checker + self.feature_extractor = feature_extractor + self.requires_safety_checker = requires_safety_checker + self.vae = vae + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.progress_bar = lambda x: tqdm(x, leave=False) + + self.clip_skip = clip_skip + self.tokenizers = tokenizer + self.text_encoders = text_encoder + + # self.__init__additional__() + + # def __init__additional__(self): + # if not hasattr(self, "vae_scale_factor"): + # setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1)) + + def to(self, device=None, dtype=None): + if device is not None: + self.device = device + # self.vae.to(device=self.device) + if dtype is not None: + self.dtype = dtype + + # do not move Text Encoders to device, because Text Encoder should be on CPU + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def check_inputs(self, prompt, height, width, strength, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." + ) + + def get_timesteps(self, num_inference_steps, strength, device, is_text2img): + if is_text2img: + return self.scheduler.timesteps.to(device), num_inference_steps + else: + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(device) + return timesteps, num_inference_steps - t_start + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype)) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + with torch.no_grad(): + latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents + + # print("post_quant_conv dtype:", self.vae.post_quant_conv.weight.dtype) # torch.float32 + # x = torch.nn.functional.conv2d(latents, self.vae.post_quant_conv.weight.detach(), stride=1, padding=0) + # print("latents dtype:", latents.dtype, "x dtype:", x.dtype) # torch.float32, torch.float16 + # self.vae.to("cpu") + # self.vae.set_use_memory_efficient_attention_xformers(False) + # image = self.vae.decode(latents.to("cpu")).sample + + tensors = self.vae.decode(latents.to(self.vae.dtype)).sample + image = (tensors / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image, tensors + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None): + if image is None: + shape = ( + batch_size, + self.unet.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents, None, None + else: + init_latent_dist = self.vae.encode(image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents + init_latents = torch.cat([init_latents] * batch_size, dim=0) + init_latents_orig = init_latents + shape = init_latents.shape + + # add noise to latents using the timesteps + if device.type == "mps": + noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + noise = torch.randn(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(init_latents, noise, timestep) + return latents, init_latents_orig, noise + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + strength: float = 0.8, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + controlnet: sdxl_original_control_net.SdxlControlNet = None, + controlnet_image=None, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. + `image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + controlnet (`diffusers.ControlNetModel`, *optional*): + A controlnet model to be used for the inference. If not provided, controlnet will be disabled. + controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*): + `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet + inference. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + `None` if cancelled by `is_cancelled_callback`, + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + if controlnet is not None and controlnet_image is None: + raise ValueError("controlnet_image must be provided if controlnet is not None.") + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, strength, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + encoding_strategy: strategy_sdxl.SdxlTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + text_input_ids, text_weights = tokenize_strategy.tokenize_with_weights(prompt) + hidden_states_1, hidden_states_2, text_pool = encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, self.text_encoders, text_input_ids, text_weights + ) + text_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1) + + if do_classifier_free_guidance: + input_ids, weights = tokenize_strategy.tokenize_with_weights(negative_prompt or "") + hidden_states_1, hidden_states_2, uncond_pool = encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, self.text_encoders, input_ids, weights + ) + uncond_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1) + else: + uncond_embeddings = None + uncond_pool = None + + unet_dtype = self.unet.dtype + dtype = unet_dtype + if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8 + dtype = torch.float16 + self.unet.to(dtype) + + # 4. Preprocess image and mask + if isinstance(image, PIL.Image.Image): + image = preprocess_image(image) + if image is not None: + image = image.to(device=self.device, dtype=dtype) + if isinstance(mask_image, PIL.Image.Image): + mask_image = preprocess_mask(mask_image, self.vae_scale_factor) + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=dtype) + mask = torch.cat([mask] * batch_size * num_images_per_prompt) + else: + mask = None + + # ControlNet is not working yet in SDXL, but keep the code here for future use + if controlnet_image is not None: + controlnet_image = prepare_controlnet_image( + controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False + ) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents, init_latents_orig, noise = self.prepare_latents( + image, + latent_timestep, + batch_size * num_images_per_prompt, + height, + width, + dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # create size embs and concat embeddings for SDXL + orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(device, dtype) + crop_size = torch.zeros_like(orig_size) + target_size = orig_size + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(device, dtype) + + # make conditionings + text_pool = text_pool.to(device, dtype) + if do_classifier_free_guidance: + text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(device, dtype) + + uncond_pool = uncond_pool.to(device, dtype) + cond_vector = torch.cat([text_pool, embs], dim=1).to(dtype) + uncond_vector = torch.cat([uncond_pool, embs], dim=1).to(dtype) + vector_embedding = torch.cat([uncond_vector, cond_vector]) + else: + text_embedding = text_embeddings.to(device, dtype) + vector_embedding = torch.cat([text_pool, embs], dim=1) + + # 8. Denoising loop + comfy_pbar = ProgressBar(total=len(timesteps)) + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # FIXME SD1 ControlNet is not working + + # predict the noise residual + if controlnet is not None: + input_resi_add, mid_add = controlnet(latent_model_input, t, text_embedding, vector_embedding, controlnet_image) + noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding, input_resi_add, mid_add) + else: + noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding) + noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + comfy_pbar.update(1) + + self.unet.to(unet_dtype) + return latents + + def latents_to_image(self, latents): + # 9. Post-processing + image, tensors = self.decode_latents(latents.to(self.vae.dtype)) + image = self.numpy_to_pil(image) + return image, tensors + + # copy from pil_utils.py + def numpy_to_pil(self, images: np.ndarray) -> Image.Image: + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + def text2img( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + ): + r""" + Function for text-to-image generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + is_cancelled_callback=is_cancelled_callback, + callback_steps=callback_steps, + ) + + def img2img( + self, + image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + ): + r""" + Function for image-to-image generation. + Args: + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. + `image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + image=image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + is_cancelled_callback=is_cancelled_callback, + callback_steps=callback_steps, + ) + + def inpaint( + self, + image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + ): + r""" + Function for inpaint. + Args: + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more + noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The reference number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + image=image, + mask_image=mask_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + is_cancelled_callback=is_cancelled_callback, + callback_steps=callback_steps, + ) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py new file mode 100644 index 0000000000000000000000000000000000000000..c546dec53db0fa974c43cc68bfa89695e8e67bf8 --- /dev/null +++ b/library/sdxl_model_util.py @@ -0,0 +1,583 @@ +import torch +import safetensors +from accelerate import init_empty_weights +from accelerate.utils.modeling import set_module_tensor_to_device +from safetensors.torch import load_file, save_file +from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer +from typing import List +from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel +from . import model_util +from . import sdxl_original_unet +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +VAE_SCALE_FACTOR = 0.13025 +MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0" + +# Diffusersの設定を読み込むための参照モデル +DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0" + +DIFFUSERS_SDXL_UNET_CONFIG = { + "act_fn": "silu", + "addition_embed_type": "text_time", + "addition_embed_type_num_heads": 64, + "addition_time_embed_dim": 256, + "attention_head_dim": [5, 10, 20], + "block_out_channels": [320, 640, 1280], + "center_input_sample": False, + "class_embed_type": None, + "class_embeddings_concat": False, + "conv_in_kernel": 3, + "conv_out_kernel": 3, + "cross_attention_dim": 2048, + "cross_attention_norm": None, + "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"], + "downsample_padding": 1, + "dual_cross_attention": False, + "encoder_hid_dim": None, + "encoder_hid_dim_type": None, + "flip_sin_to_cos": True, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_only_cross_attention": None, + "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_attention_heads": None, + "num_class_embeds": None, + "only_cross_attention": False, + "out_channels": 4, + "projection_class_embeddings_input_dim": 2816, + "resnet_out_scale_factor": 1.0, + "resnet_skip_time_act": False, + "resnet_time_scale_shift": "default", + "sample_size": 128, + "time_cond_proj_dim": None, + "time_embedding_act_fn": None, + "time_embedding_dim": None, + "time_embedding_type": "positional", + "timestep_post_act": None, + "transformer_layers_per_block": [1, 2, 10], + "up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"], + "upcast_attention": False, + "use_linear_projection": True, +} + + +def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length): + SDXL_KEY_PREFIX = "conditioner.embedders.1.model." + + # SD2のと、基本的には同じ。logit_scaleを後で使うので、それを追加で返す + # logit_scaleはcheckpointの保存時に使用する + def convert_key(key): + # common conversion + key = key.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.") + key = key.replace(SDXL_KEY_PREFIX, "text_model.") + + if "resblocks" in key: + # resblocks conversion + key = key.replace(".resblocks.", ".layers.") + if ".ln_" in key: + key = key.replace(".ln_", ".layer_norm") + elif ".mlp." in key: + key = key.replace(".c_fc.", ".fc1.") + key = key.replace(".c_proj.", ".fc2.") + elif ".attn.out_proj" in key: + key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") + elif ".attn.in_proj" in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in SD: {key}") + elif ".positional_embedding" in key: + key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") + elif ".text_projection" in key: + key = key.replace("text_model.text_projection", "text_projection.weight") + elif ".logit_scale" in key: + key = None # 後で処理する + elif ".token_embedding" in key: + key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") + elif ".ln_final" in key: + key = key.replace(".ln_final", ".final_layer_norm") + # ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids + elif ".embeddings.position_ids" in key: + key = None # remove this key: position_ids is not used in newer transformers + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if ".resblocks" in key and ".attn.in_proj_" in key: + # 三つに分割 + values = torch.chunk(checkpoint[key], 3) + + key_suffix = ".weight" if "weight" in key else ".bias" + key_pfx = key.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.") + key_pfx = key_pfx.replace("_weight", "") + key_pfx = key_pfx.replace("_bias", "") + key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") + new_sd[key_pfx + "q_proj" + key_suffix] = values[0] + new_sd[key_pfx + "k_proj" + key_suffix] = values[1] + new_sd[key_pfx + "v_proj" + key_suffix] = values[2] + + # logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す + logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None) + + # temporary workaround for text_projection.weight.weight for Playground-v2 + if "text_projection.weight.weight" in new_sd: + logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight") + new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"] + del new_sd["text_projection.weight.weight"] + + return new_sd, logit_scale + + +# load state_dict without allocating new tensors +def _load_state_dict_on_device(model, state_dict, device, dtype=None): + # dtype will use fp32 as default + missing_keys = list(model.state_dict().keys() - state_dict.keys()) + unexpected_keys = list(state_dict.keys() - model.state_dict().keys()) + + # similar to model.load_state_dict() + if not missing_keys and not unexpected_keys: + for k in list(state_dict.keys()): + set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype) + return "" + + # error_msgs + error_msgs: List[str] = [] + if missing_keys: + error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))) + if unexpected_keys: + error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys))) + + raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))) + + +def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False): + # model_version is reserved for future use + # dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching + + # Load the state dict + if model_util.is_safetensors(ckpt_path): + checkpoint = None + if disable_mmap: + state_dict = safetensors.torch.load(open(ckpt_path, "rb").read()) + else: + try: + state_dict = load_file(ckpt_path, device=map_location) + except: + state_dict = load_file(ckpt_path) # prevent device invalid Error + epoch = None + global_step = None + else: + checkpoint = torch.load(ckpt_path, map_location=map_location) + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + epoch = checkpoint.get("epoch", 0) + global_step = checkpoint.get("global_step", 0) + else: + state_dict = checkpoint + epoch = 0 + global_step = 0 + checkpoint = None + + # U-Net + logger.info("building U-Net") + with init_empty_weights(): + unet = sdxl_original_unet.SdxlUNet2DConditionModel() + + logger.info("loading U-Net from checkpoint") + unet_sd = {} + for k in list(state_dict.keys()): + if k.startswith("model.diffusion_model."): + unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) + info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype) + logger.info(f"U-Net: {info}") + + # Text Encoders + logger.info("building text encoders") + + # Text Encoder 1 is same to Stability AI's SDXL + text_model1_cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=77, + hidden_act="quick_gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=768, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + with init_empty_weights(): + text_model1 = CLIPTextModel._from_config(text_model1_cfg) + + # Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace. + # Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer. + text_model2_cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=1280, + intermediate_size=5120, + num_hidden_layers=32, + num_attention_heads=20, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=1280, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + with init_empty_weights(): + text_model2 = CLIPTextModelWithProjection(text_model2_cfg) + + logger.info("loading text encoders from checkpoint") + te1_sd = {} + te2_sd = {} + for k in list(state_dict.keys()): + if k.startswith("conditioner.embedders.0.transformer."): + te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k) + elif k.startswith("conditioner.embedders.1.model."): + te2_sd[k] = state_dict.pop(k) + + # 最新の transformers では position_ids を含むとエラーになるので削除 / remove position_ids for latest transformers + if "text_model.embeddings.position_ids" in te1_sd: + te1_sd.pop("text_model.embeddings.position_ids") + + info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32 + logger.info(f"text encoder 1: {info1}") + + converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77) + info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32 + logger.info(f"text encoder 2: {info2}") + + # prepare vae + logger.info("building VAE") + vae_config = model_util.create_vae_diffusers_config() + with init_empty_weights(): + vae = AutoencoderKL(**vae_config) + + logger.info("loading VAE from checkpoint") + converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config) + info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype) + logger.info(f"VAE: {info}") + + ckpt_info = (epoch, global_step) if epoch is not None else None + return text_model1, text_model2, vae, unet, logit_scale, ckpt_info + + +def make_unet_conversion_map(): + unet_conversion_map_layer = [] + + for i in range(3): # num_blocks is 3 in sdxl + # loop over downblocks/upblocks + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + # if i > 0: commentout for sdxl + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0.", "norm1."), + ("in_layers.2.", "conv1."), + ("out_layers.0.", "norm2."), + ("out_layers.3.", "conv2."), + ("emb_layers.1.", "time_emb_proj."), + ("skip_connection.", "conv_shortcut."), + ] + + unet_conversion_map = [] + for sd, hf in unet_conversion_map_layer: + if "resnets" in hf: + for sd_res, hf_res in unet_conversion_map_resnet: + unet_conversion_map.append((sd + sd_res, hf + hf_res)) + else: + unet_conversion_map.append((sd, hf)) + + for j in range(2): + hf_time_embed_prefix = f"time_embedding.linear_{j+1}." + sd_time_embed_prefix = f"time_embed.{j*2}." + unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) + + for j in range(2): + hf_label_embed_prefix = f"add_embedding.linear_{j+1}." + sd_label_embed_prefix = f"label_emb.0.{j*2}." + unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) + + unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) + unet_conversion_map.append(("out.0.", "conv_norm_out.")) + unet_conversion_map.append(("out.2.", "conv_out.")) + + return unet_conversion_map + + +def convert_diffusers_unet_state_dict_to_sdxl(du_sd): + unet_conversion_map = make_unet_conversion_map() + + conversion_map = {hf: sd for sd, hf in unet_conversion_map} + return convert_unet_state_dict(du_sd, conversion_map) + + +def convert_unet_state_dict(src_sd, conversion_map): + converted_sd = {} + for src_key, value in src_sd.items(): + # さすがに全部回すのは時間がかかるので右から要素を削りつつprefixを探す + src_key_fragments = src_key.split(".")[:-1] # remove weight/bias + while len(src_key_fragments) > 0: + src_key_prefix = ".".join(src_key_fragments) + "." + if src_key_prefix in conversion_map: + converted_prefix = conversion_map[src_key_prefix] + converted_key = converted_prefix + src_key[len(src_key_prefix) :] + converted_sd[converted_key] = value + break + src_key_fragments.pop(-1) + assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map" + + return converted_sd + + +def convert_sdxl_unet_state_dict_to_diffusers(sd): + unet_conversion_map = make_unet_conversion_map() + + conversion_dict = {sd: hf for sd, hf in unet_conversion_map} + return convert_unet_state_dict(sd, conversion_dict) + + +def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale): + def convert_key(key): + # position_idsの除去 + if ".position_ids" in key: + return None + + # common + key = key.replace("text_model.encoder.", "transformer.") + key = key.replace("text_model.", "") + if "layers" in key: + # resblocks conversion + key = key.replace(".layers.", ".resblocks.") + if ".layer_norm" in key: + key = key.replace(".layer_norm", ".ln_") + elif ".mlp." in key: + key = key.replace(".fc1.", ".c_fc.") + key = key.replace(".fc2.", ".c_proj.") + elif ".self_attn.out_proj" in key: + key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") + elif ".self_attn." in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in DiffUsers model: {key}") + elif ".position_embedding" in key: + key = key.replace("embeddings.position_embedding.weight", "positional_embedding") + elif ".token_embedding" in key: + key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") + elif "text_projection" in key: # no dot in key + key = key.replace("text_projection.weight", "text_projection") + elif "final_layer_norm" in key: + key = key.replace("final_layer_norm", "ln_final") + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if "layers" in key and "q_proj" in key: + # 三つを結合 + key_q = key + key_k = key.replace("q_proj", "k_proj") + key_v = key.replace("q_proj", "v_proj") + + value_q = checkpoint[key_q] + value_k = checkpoint[key_k] + value_v = checkpoint[key_v] + value = torch.cat([value_q, value_k, value_v]) + + new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") + new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") + new_sd[new_key] = value + + if logit_scale is not None: + new_sd["logit_scale"] = logit_scale + + return new_sd + + +def save_stable_diffusion_checkpoint( + output_file, + text_encoder1, + text_encoder2, + unet, + epochs, + steps, + ckpt_info, + vae, + logit_scale, + metadata, + save_dtype=None, +): + state_dict = {} + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + # Convert the UNet model + update_sd("model.diffusion_model.", unet.state_dict()) + + # Convert the text encoders + update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict()) + + text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale) + update_sd("conditioner.embedders.1.model.", text_enc2_dict) + + # Convert the VAE + vae_dict = model_util.convert_vae_state_dict(vae.state_dict()) + update_sd("first_stage_model.", vae_dict) + + # Put together new checkpoint + key_count = len(state_dict.keys()) + new_ckpt = {"state_dict": state_dict} + + # epoch and global_step are sometimes not int + if ckpt_info is not None: + epochs += ckpt_info[0] + steps += ckpt_info[1] + + new_ckpt["epoch"] = epochs + new_ckpt["global_step"] = steps + + if model_util.is_safetensors(output_file): + save_file(state_dict, output_file, metadata) + else: + torch.save(new_ckpt, output_file) + + return key_count + + +def save_diffusers_checkpoint( + output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None +): + from diffusers import StableDiffusionXLPipeline + + # convert U-Net + unet_sd = unet.state_dict() + du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd) + + diffusers_unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG) + if save_dtype is not None: + diffusers_unet.to(save_dtype) + diffusers_unet.load_state_dict(du_unet_sd) + + # create pipeline to save + if pretrained_model_name_or_path is None: + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_SDXL + + scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") + tokenizer1 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") + tokenizer2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2") + if vae is None: + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") + + # prevent local path from being saved + def remove_name_or_path(model): + if hasattr(model, "config"): + model.config._name_or_path = None + model.config._name_or_path = None + + remove_name_or_path(diffusers_unet) + remove_name_or_path(text_encoder1) + remove_name_or_path(text_encoder2) + remove_name_or_path(scheduler) + remove_name_or_path(tokenizer1) + remove_name_or_path(tokenizer2) + remove_name_or_path(vae) + + pipeline = StableDiffusionXLPipeline( + unet=diffusers_unet, + text_encoder=text_encoder1, + text_encoder_2=text_encoder2, + vae=vae, + scheduler=scheduler, + tokenizer=tokenizer1, + tokenizer_2=tokenizer2, + ) + if save_dtype is not None: + pipeline.to(None, save_dtype) + pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) diff --git a/library/sdxl_original_control_net.py b/library/sdxl_original_control_net.py new file mode 100644 index 0000000000000000000000000000000000000000..b30b7448f5f26630e8e9294c18bb1f60f6bbcfd0 --- /dev/null +++ b/library/sdxl_original_control_net.py @@ -0,0 +1,272 @@ +# some parts are modified from Diffusers library (Apache License 2.0) + +import math +from types import SimpleNamespace +from typing import Any, Optional +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from einops import rearrange +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from . import sdxl_original_unet +from .sdxl_model_util import convert_sdxl_unet_state_dict_to_diffusers, convert_diffusers_unet_state_dict_to_sdxl + + +class ControlNetConditioningEmbedding(nn.Module): + def __init__(self): + super().__init__() + + dims = [16, 32, 96, 256] + + self.conv_in = nn.Conv2d(3, dims[0], kernel_size=3, padding=1) + self.blocks = nn.ModuleList([]) + + for i in range(len(dims) - 1): + channel_in = dims[i] + channel_out = dims[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = nn.Conv2d(dims[-1], 320, kernel_size=3, padding=1) + nn.init.zeros_(self.conv_out.weight) # zero module weight + nn.init.zeros_(self.conv_out.bias) # zero module bias + + def forward(self, x): + x = self.conv_in(x) + x = F.silu(x) + for block in self.blocks: + x = block(x) + x = F.silu(x) + x = self.conv_out(x) + return x + + +class SdxlControlNet(sdxl_original_unet.SdxlUNet2DConditionModel): + def __init__(self, multiplier: Optional[float] = None, **kwargs): + super().__init__(**kwargs) + self.multiplier = multiplier + + # remove unet layers + self.output_blocks = nn.ModuleList([]) + del self.out + + self.controlnet_cond_embedding = ControlNetConditioningEmbedding() + + dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280] + self.controlnet_down_blocks = nn.ModuleList([]) + for dim in dims: + self.controlnet_down_blocks.append(nn.Conv2d(dim, dim, kernel_size=1)) + nn.init.zeros_(self.controlnet_down_blocks[-1].weight) # zero module weight + nn.init.zeros_(self.controlnet_down_blocks[-1].bias) # zero module bias + + self.controlnet_mid_block = nn.Conv2d(1280, 1280, kernel_size=1) + nn.init.zeros_(self.controlnet_mid_block.weight) # zero module weight + nn.init.zeros_(self.controlnet_mid_block.bias) # zero module bias + + def init_from_unet(self, unet: sdxl_original_unet.SdxlUNet2DConditionModel): + unet_sd = unet.state_dict() + unet_sd = {k: v for k, v in unet_sd.items() if not k.startswith("out")} + sd = super().state_dict() + sd.update(unet_sd) + info = super().load_state_dict(sd, strict=True, assign=True) + return info + + def load_state_dict(self, state_dict: dict, strict: bool = True, assign: bool = True) -> Any: + # convert state_dict to SAI format + unet_sd = {} + for k in list(state_dict.keys()): + if not k.startswith("controlnet_"): + unet_sd[k] = state_dict.pop(k) + unet_sd = convert_diffusers_unet_state_dict_to_sdxl(unet_sd) + state_dict.update(unet_sd) + super().load_state_dict(state_dict, strict=strict, assign=assign) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + # convert state_dict to Diffusers format + state_dict = super().state_dict(destination, prefix, keep_vars) + control_net_sd = {} + for k in list(state_dict.keys()): + if k.startswith("controlnet_"): + control_net_sd[k] = state_dict.pop(k) + state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict) + state_dict.update(control_net_sd) + return state_dict + + def forward( + self, + x: torch.Tensor, + timesteps: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + cond_image: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) + t_emb = t_emb.to(x.dtype) + emb = self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + emb = emb + self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + if isinstance(layer, sdxl_original_unet.ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, sdxl_original_unet.Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + h = x + multiplier = self.multiplier if self.multiplier is not None else 1.0 + hs = [] + for i, module in enumerate(self.input_blocks): + h = call_module(module, h, emb, context) + if i == 0: + h = self.controlnet_cond_embedding(cond_image) + h + hs.append(self.controlnet_down_blocks[i](h) * multiplier) + + h = call_module(self.middle_block, h, emb, context) + h = self.controlnet_mid_block(h) * multiplier + + return hs, h + + +class SdxlControlledUNet(sdxl_original_unet.SdxlUNet2DConditionModel): + """ + This class is for training purpose only. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs): + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + hs = [] + t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) + t_emb = t_emb.to(x.dtype) + emb = self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + emb = emb + self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + if isinstance(layer, sdxl_original_unet.ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, sdxl_original_unet.Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + h = x + for module in self.input_blocks: + h = call_module(module, h, emb, context) + hs.append(h) + + h = call_module(self.middle_block, h, emb, context) + h = h + mid_add + + for module in self.output_blocks: + resi = hs.pop() + input_resi_add.pop() + h = torch.cat([h, resi], dim=1) + h = call_module(module, h, emb, context) + + h = h.type(x.dtype) + h = call_module(self.out, h, emb, context) + + return h + + +if __name__ == "__main__": + import time + + logger.info("create unet") + unet = SdxlControlledUNet() + unet.to("cuda", torch.bfloat16) + unet.set_use_sdpa(True) + unet.set_gradient_checkpointing(True) + unet.train() + + logger.info("create control_net") + control_net = SdxlControlNet() + control_net.to("cuda") + control_net.set_use_sdpa(True) + control_net.set_gradient_checkpointing(True) + control_net.train() + + logger.info("Initialize control_net from unet") + control_net.init_from_unet(unet) + + unet.requires_grad_(False) + control_net.requires_grad_(True) + + # 使用メモリ量確認用の疑似学習ループ + logger.info("preparing optimizer") + + # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working + + import bitsandbytes + + optimizer = bitsandbytes.adam.Adam8bit(control_net.parameters(), lr=1e-3) # not working + # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 + # optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 + + # import transformers + # optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2 + + scaler = torch.cuda.amp.GradScaler(enabled=True) + + logger.info("start training") + steps = 10 + batch_size = 1 + + for step in range(steps): + logger.info(f"step {step}") + if step == 1: + time_start = time.perf_counter() + + x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024 + t = torch.randint(low=0, high=1000, size=(batch_size,), device="cuda") + txt = torch.randn(batch_size, 77, 2048).cuda() + vector = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda() + cond_img = torch.rand(batch_size, 3, 1024, 1024).cuda() + + with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16): + input_resi_add, mid_add = control_net(x, t, txt, vector, cond_img) + output = unet(x, t, txt, vector, input_resi_add, mid_add) + target = torch.randn_like(output) + loss = torch.nn.functional.mse_loss(output, target) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad(set_to_none=True) + + time_end = time.perf_counter() + logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") + + logger.info("finish training") + sd = control_net.state_dict() + + from safetensors.torch import save_file + + save_file(sd, r"E:\Work\SD\Tmp\sdxl\ctrl\control_net.safetensors") diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..f4c2cc069f981262240c6662446640baf5d238c7 --- /dev/null +++ b/library/sdxl_original_unet.py @@ -0,0 +1,1292 @@ +# Diffusersのコードをベースとした sd_xl_baseのU-Net +# state dictの形式をSDXLに合わせてある + +""" + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + adm_in_channels: 2816 + num_classes: sequential + use_checkpoint: True + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 64 + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16 + context_dim: 2048 + spatial_transformer_attn_type: softmax-xformers + legacy: False +""" + +import math +from types import SimpleNamespace +from typing import Any, Optional +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from einops import rearrange +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +IN_CHANNELS: int = 4 +OUT_CHANNELS: int = 4 +ADM_IN_CHANNELS: int = 2816 +CONTEXT_DIM: int = 2048 +MODEL_CHANNELS: int = 320 +TIME_EMBED_DIM = 320 * 4 + +USE_REENTRANT = True + +# region memory efficient attention + +# FlashAttentionを使うCrossAttention +# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py +# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE + +# constants + +EPSILON = 1e-6 + +# helper functions + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +# flash attention forwards and backwards + +# https://arxiv.org/abs/2205.14135 + + +class FlashAttentionFunction(torch.autograd.Function): + @staticmethod + @torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) + + scale = q.shape[-1] ** -0.5 + + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.0) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.0) + + p = exp_attn_weights / lc + + dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc) + dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +# endregion + + +def get_parameter_dtype(parameter: torch.nn.Module): + return next(parameter.parameters()).dtype + + +def get_parameter_device(parameter: torch.nn.Module): + return next(parameter.parameters()).device + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings: flipped from Diffusers original ver because always flip_sin_to_cos=True + emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +# Deep Shrink: We do not common this function, because minimize dependencies. +def resize_like(x, target, mode="bicubic", align_corners=False): + org_dtype = x.dtype + if org_dtype == torch.bfloat16: + x = x.to(torch.float32) + + if x.shape[-2:] != target.shape[-2:]: + if mode == "nearest": + x = F.interpolate(x, size=target.shape[-2:], mode=mode) + else: + x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners) + + if org_dtype == torch.bfloat16: + x = x.to(org_dtype) + return x + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + if self.weight.dtype != torch.float32: + return super().forward(x) + return super().forward(x.float()).type(x.dtype) + + +class ResnetBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.in_layers = nn.Sequential( + GroupNorm32(32, in_channels), + nn.SiLU(), + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), + ) + + self.emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(TIME_EMBED_DIM, out_channels)) + + self.out_layers = nn.Sequential( + GroupNorm32(32, out_channels), + nn.SiLU(), + nn.Identity(), # to make state_dict compatible with original model + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), + ) + + if in_channels != out_channels: + self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + else: + self.skip_connection = nn.Identity() + + self.gradient_checkpointing = False + + def forward_body(self, x, emb): + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + h = h + emb_out[:, :, None, None] + h = self.out_layers(h) + x = self.skip_connection(x) + return x + h + + def forward(self, x, emb): + if self.training and self.gradient_checkpointing: + # logger.info("ResnetBlock2D: gradient_checkpointing") + + def create_custom_forward(func): + def custom_forward(*inputs): + return func(*inputs) + + return custom_forward + + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb, use_reentrant=USE_REENTRANT) + else: + x = self.forward_body(x, emb) + + return x + + +class Downsample2D(nn.Module): + def __init__(self, channels, out_channels): + super().__init__() + + self.channels = channels + self.out_channels = out_channels + + self.op = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1) + + self.gradient_checkpointing = False + + def forward_body(self, hidden_states): + assert hidden_states.shape[1] == self.channels + hidden_states = self.op(hidden_states) + + return hidden_states + + def forward(self, hidden_states): + if self.training and self.gradient_checkpointing: + # logger.info("Downsample2D: gradient_checkpointing") + + def create_custom_forward(func): + def custom_forward(*inputs): + return func(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.forward_body), hidden_states, use_reentrant=USE_REENTRANT + ) + else: + hidden_states = self.forward_body(hidden_states) + + return hidden_states + + +class CrossAttention(nn.Module): + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + upcast_attention: bool = False, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + # no dropout here + + self.use_memory_efficient_attention_xformers = False + self.use_memory_efficient_attention_mem_eff = False + self.use_sdpa = False + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + self.use_memory_efficient_attention_xformers = xformers + self.use_memory_efficient_attention_mem_eff = mem_eff + + def set_use_sdpa(self, sdpa): + self.use_sdpa = sdpa + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def forward(self, hidden_states, context=None, mask=None): + if self.use_memory_efficient_attention_xformers: + return self.forward_memory_efficient_xformers(hidden_states, context, mask) + if self.use_memory_efficient_attention_mem_eff: + return self.forward_memory_efficient_mem_eff(hidden_states, context, mask) + if self.use_sdpa: + return self.forward_sdpa(hidden_states, context, mask) + + query = self.to_q(hidden_states) + context = context if context is not None else hidden_states + key = self.to_k(context) + value = self.to_v(context) + + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + hidden_states = self._attention(query, key, value) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # hidden_states = self.to_out[1](hidden_states) # no dropout + return hidden_states + + def _attention(self, query, key, value): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + # TODO support Hypernetworks + def forward_memory_efficient_xformers(self, x, context=None, mask=None): + import xformers.ops + + h = self.heads + q_in = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k_in = self.to_k(context) + v_in = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + del q, k, v + + out = rearrange(out, "b n h d -> b n (h d)", h=h) + + out = self.to_out[0](out) + return out + + def forward_memory_efficient_mem_eff(self, x, context=None, mask=None): + flash_func = FlashAttentionFunction + + q_bucket_size = 512 + k_bucket_size = 1024 + + h = self.heads + q = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + out = self.to_out[0](out) + return out + + def forward_sdpa(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k_in = self.to_k(context) + v_in = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + + out = rearrange(out, "b h n d -> b n (h d)", h=h) + + out = self.to_out[0](out) + return out + + +# feedforward +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + ): + super().__init__() + inner_dim = int(dim * 4) # mult is always 4 + + self.net = nn.ModuleList([]) + # project in + self.net.append(GEGLU(dim, inner_dim)) + # project dropout + self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0 + # project out + self.net.append(nn.Linear(inner_dim, dim)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False + ): + super().__init__() + + self.gradient_checkpointing = False + + # 1. Self-Attn + self.attn1 = CrossAttention( + query_dim=dim, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + upcast_attention=upcast_attention, + ) + self.ff = FeedForward(dim) + + # 2. Cross-Attn + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + upcast_attention=upcast_attention, + ) + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool): + self.attn1.set_use_memory_efficient_attention(xformers, mem_eff) + self.attn2.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa: bool): + self.attn1.set_use_sdpa(sdpa) + self.attn2.set_use_sdpa(sdpa) + + def forward_body(self, hidden_states, context=None, timestep=None): + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + hidden_states = self.attn1(norm_hidden_states) + hidden_states + + # 2. Cross-Attention + norm_hidden_states = self.norm2(hidden_states) + hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states + + # 3. Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + return hidden_states + + def forward(self, hidden_states, context=None, timestep=None): + if self.training and self.gradient_checkpointing: + # logger.info("BasicTransformerBlock: checkpointing") + + def create_custom_forward(func): + def custom_forward(*inputs): + return func(*inputs) + + return custom_forward + + output = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.forward_body), hidden_states, context, timestep, use_reentrant=USE_REENTRANT + ) + else: + output = self.forward_body(hidden_states, context, timestep) + + return output + + +class Transformer2DModel(nn.Module): + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + use_linear_projection: bool = False, + upcast_attention: bool = False, + num_transformer_layers: int = 1, + ): + super().__init__() + self.in_channels = in_channels + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.use_linear_projection = use_linear_projection + + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + # self.norm = GroupNorm32(32, in_channels, eps=1e-6, affine=True) + + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + blocks = [] + for _ in range(num_transformer_layers): + blocks.append( + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + ) + + self.transformer_blocks = nn.ModuleList(blocks) + + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + for transformer in self.transformer_blocks: + transformer.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa): + for transformer in self.transformer_blocks: + transformer.set_use_sdpa(sdpa) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None): + # 1. Input + batch, _, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep) + + # 3. Output + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + + return output + + +class Upsample2D(nn.Module): + def __init__(self, channels, out_channels): + super().__init__() + self.channels = channels + self.out_channels = out_channels + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward_body(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + hidden_states = self.conv(hidden_states) + + return hidden_states + + def forward(self, hidden_states, output_size=None): + if self.training and self.gradient_checkpointing: + # logger.info("Upsample2D: gradient_checkpointing") + + def create_custom_forward(func): + def custom_forward(*inputs): + return func(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.forward_body), hidden_states, output_size, use_reentrant=USE_REENTRANT + ) + else: + hidden_states = self.forward_body(hidden_states, output_size) + + return hidden_states + + +class SdxlUNet2DConditionModel(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + **kwargs, + ): + super().__init__() + + self.in_channels = IN_CHANNELS + self.out_channels = OUT_CHANNELS + self.model_channels = MODEL_CHANNELS + self.time_embed_dim = TIME_EMBED_DIM + self.adm_in_channels = ADM_IN_CHANNELS + + self.gradient_checkpointing = False + # self.sample_size = sample_size + + # time embedding + self.time_embed = nn.Sequential( + nn.Linear(self.model_channels, self.time_embed_dim), + nn.SiLU(), + nn.Linear(self.time_embed_dim, self.time_embed_dim), + ) + + # label embedding + self.label_emb = nn.Sequential( + nn.Sequential( + nn.Linear(self.adm_in_channels, self.time_embed_dim), + nn.SiLU(), + nn.Linear(self.time_embed_dim, self.time_embed_dim), + ) + ) + + # input + self.input_blocks = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d(self.in_channels, self.model_channels, kernel_size=3, padding=(1, 1)), + ) + ] + ) + + # level 0 + for i in range(2): + layers = [ + ResnetBlock2D( + in_channels=1 * self.model_channels, + out_channels=1 * self.model_channels, + ), + ] + self.input_blocks.append(nn.ModuleList(layers)) + + self.input_blocks.append( + nn.Sequential( + Downsample2D( + channels=1 * self.model_channels, + out_channels=1 * self.model_channels, + ), + ) + ) + + # level 1 + for i in range(2): + layers = [ + ResnetBlock2D( + in_channels=(1 if i == 0 else 2) * self.model_channels, + out_channels=2 * self.model_channels, + ), + Transformer2DModel( + num_attention_heads=2 * self.model_channels // 64, + attention_head_dim=64, + in_channels=2 * self.model_channels, + num_transformer_layers=2, + use_linear_projection=True, + cross_attention_dim=2048, + ), + ] + self.input_blocks.append(nn.ModuleList(layers)) + + self.input_blocks.append( + nn.Sequential( + Downsample2D( + channels=2 * self.model_channels, + out_channels=2 * self.model_channels, + ), + ) + ) + + # level 2 + for i in range(2): + layers = [ + ResnetBlock2D( + in_channels=(2 if i == 0 else 4) * self.model_channels, + out_channels=4 * self.model_channels, + ), + Transformer2DModel( + num_attention_heads=4 * self.model_channels // 64, + attention_head_dim=64, + in_channels=4 * self.model_channels, + num_transformer_layers=10, + use_linear_projection=True, + cross_attention_dim=2048, + ), + ] + self.input_blocks.append(nn.ModuleList(layers)) + + # mid + self.middle_block = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=4 * self.model_channels, + out_channels=4 * self.model_channels, + ), + Transformer2DModel( + num_attention_heads=4 * self.model_channels // 64, + attention_head_dim=64, + in_channels=4 * self.model_channels, + num_transformer_layers=10, + use_linear_projection=True, + cross_attention_dim=2048, + ), + ResnetBlock2D( + in_channels=4 * self.model_channels, + out_channels=4 * self.model_channels, + ), + ] + ) + + # output + self.output_blocks = nn.ModuleList([]) + + # level 2 + for i in range(3): + layers = [ + ResnetBlock2D( + in_channels=4 * self.model_channels + (4 if i <= 1 else 2) * self.model_channels, + out_channels=4 * self.model_channels, + ), + Transformer2DModel( + num_attention_heads=4 * self.model_channels // 64, + attention_head_dim=64, + in_channels=4 * self.model_channels, + num_transformer_layers=10, + use_linear_projection=True, + cross_attention_dim=2048, + ), + ] + if i == 2: + layers.append( + Upsample2D( + channels=4 * self.model_channels, + out_channels=4 * self.model_channels, + ) + ) + + self.output_blocks.append(nn.ModuleList(layers)) + + # level 1 + for i in range(3): + layers = [ + ResnetBlock2D( + in_channels=2 * self.model_channels + (4 if i == 0 else (2 if i == 1 else 1)) * self.model_channels, + out_channels=2 * self.model_channels, + ), + Transformer2DModel( + num_attention_heads=2 * self.model_channels // 64, + attention_head_dim=64, + in_channels=2 * self.model_channels, + num_transformer_layers=2, + use_linear_projection=True, + cross_attention_dim=2048, + ), + ] + if i == 2: + layers.append( + Upsample2D( + channels=2 * self.model_channels, + out_channels=2 * self.model_channels, + ) + ) + + self.output_blocks.append(nn.ModuleList(layers)) + + # level 0 + for i in range(3): + layers = [ + ResnetBlock2D( + in_channels=1 * self.model_channels + (2 if i == 0 else 1) * self.model_channels, + out_channels=1 * self.model_channels, + ), + ] + + self.output_blocks.append(nn.ModuleList(layers)) + + # output + self.out = nn.ModuleList( + [GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)] + ) + + # region diffusers compatibility + def prepare_config(self): + self.config = SimpleNamespace() + + @property + def dtype(self) -> torch.dtype: + # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + return get_parameter_dtype(self) + + @property + def device(self) -> torch.device: + # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device). + return get_parameter_device(self) + + def set_attention_slice(self, slice_size): + raise NotImplementedError("Attention slicing is not supported for this model.") + + def is_gradient_checkpointing(self) -> bool: + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + self.set_gradient_checkpointing(value=True) + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.set_gradient_checkpointing(value=False) + + def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None: + blocks = self.input_blocks + [self.middle_block] + self.output_blocks + for block in blocks: + for module in block: + if hasattr(module, "set_use_memory_efficient_attention"): + # logger.info(module.__class__.__name__) + module.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa: bool) -> None: + blocks = self.input_blocks + [self.middle_block] + self.output_blocks + for block in blocks: + for module in block: + if hasattr(module, "set_use_sdpa"): + module.set_use_sdpa(sdpa) + + def set_gradient_checkpointing(self, value=False): + blocks = self.input_blocks + [self.middle_block] + self.output_blocks + for block in blocks: + for module in block.modules(): + if hasattr(module, "gradient_checkpointing"): + # logger.info(f{module.__class__.__name__} {module.gradient_checkpointing} -> {value}") + module.gradient_checkpointing = value + + # endregion + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + hs = [] + t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) # , repeat_only=False) + t_emb = t_emb.to(x.dtype) + emb = self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + # assert x.dtype == self.dtype + emb = emb + self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + # logger.info(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None) + if isinstance(layer, ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + # h = x.type(self.dtype) + h = x + + for module in self.input_blocks: + h = call_module(module, h, emb, context) + hs.append(h) + + h = call_module(self.middle_block, h, emb, context) + + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = call_module(module, h, emb, context) + + h = h.type(x.dtype) + h = call_module(self.out, h, emb, context) + + return h + + +class InferSdxlUNet2DConditionModel: + def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs): + self.delegate = original_unet + + # override original model's forward method: because forward is not called by `__call__` + # overriding `__call__` is not enough, because nn.Module.forward has a special handling + self.delegate.forward = self.forward + + # Deep Shrink + self.ds_depth_1 = None + self.ds_depth_2 = None + self.ds_timesteps_1 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + + # call original model's methods + def __getattr__(self, name): + return getattr(self.delegate, name) + + def __call__(self, *args, **kwargs): + return self.delegate(*args, **kwargs) + + def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): + if ds_depth_1 is None: + logger.info("Deep Shrink is disabled.") + self.ds_depth_1 = None + self.ds_timesteps_1 = None + self.ds_depth_2 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + else: + logger.info( + f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" + ) + self.ds_depth_1 = ds_depth_1 + self.ds_timesteps_1 = ds_timesteps_1 + self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1 + self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 + self.ds_ratio = ds_ratio + + def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs): + r""" + current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink and ControlNet. + """ + _self = self.delegate + + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + hs = [] + t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0) # , repeat_only=False) + t_emb = t_emb.to(x.dtype) + emb = _self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + # assert x.dtype == _self.dtype + emb = emb + _self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None) + if isinstance(layer, ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + # h = x.type(self.dtype) + h = x + + for depth, module in enumerate(_self.input_blocks): + # Deep Shrink + if self.ds_depth_1 is not None: + if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or ( + self.ds_depth_2 is not None + and depth == self.ds_depth_2 + and timesteps[0] < self.ds_timesteps_1 + and timesteps[0] >= self.ds_timesteps_2 + ): + # print("downsample", h.shape, self.ds_ratio) + org_dtype = h.dtype + if org_dtype == torch.bfloat16: + h = h.to(torch.float32) + h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype) + + h = call_module(module, h, emb, context) + hs.append(h) + + h = call_module(_self.middle_block, h, emb, context) + if mid_add is not None: + h = h + mid_add + + for module in _self.output_blocks: + # Deep Shrink + if self.ds_depth_1 is not None: + if hs[-1].shape[-2:] != h.shape[-2:]: + # print("upsample", h.shape, hs[-1].shape) + h = resize_like(h, hs[-1]) + + resi = hs.pop() + if input_resi_add is not None: + resi = resi + input_resi_add.pop() + + h = torch.cat([h, resi], dim=1) + h = call_module(module, h, emb, context) + + # Deep Shrink: in case of depth 0 + if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]: + # print("upsample", h.shape, x.shape) + h = resize_like(h, x) + + h = h.type(x.dtype) + h = call_module(_self.out, h, emb, context) + + return h + + +if __name__ == "__main__": + import time + + logger.info("create unet") + unet = SdxlUNet2DConditionModel() + + unet.to("cuda") + unet.set_use_memory_efficient_attention(True, False) + unet.set_gradient_checkpointing(True) + unet.train() + + # 使用メモリ量確認用の疑似学習ループ + logger.info("preparing optimizer") + + # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working + + # import bitsandbytes + # optimizer = bitsandbytes.adam.Adam8bit(unet.parameters(), lr=1e-3) # not working + # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 + # optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2 + + import transformers + + optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2 + + scaler = torch.cuda.amp.GradScaler(enabled=True) + + logger.info("start training") + steps = 10 + batch_size = 1 + + for step in range(steps): + logger.info(f"step {step}") + if step == 1: + time_start = time.perf_counter() + + x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024 + t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda") + ctx = torch.randn(batch_size, 77, 2048).cuda() + y = torch.randn(batch_size, ADM_IN_CHANNELS).cuda() + + with torch.cuda.amp.autocast(enabled=True): + output = unet(x, t, ctx, y) + target = torch.randn_like(output) + loss = torch.nn.functional.mse_loss(output, target) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad(set_to_none=True) + + time_end = time.perf_counter() + logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps") diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py new file mode 100644 index 0000000000000000000000000000000000000000..4004784f2a8eca8452307386e2c5ca014dd69450 --- /dev/null +++ b/library/sdxl_train_util.py @@ -0,0 +1,381 @@ +import argparse +import math +import os +from typing import Optional + +import torch +from .device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate import init_empty_weights +from tqdm import tqdm +from transformers import CLIPTokenizer +from . import model_util, sdxl_model_util, train_util, sdxl_original_unet +from .sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +TOKENIZER1_PATH = "openai/clip-vit-large-patch14" +TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + +# DEFAULT_NOISE_OFFSET = 0.0357 + + +def load_target_model(args, accelerator, model_version: str, weight_dtype): + model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16 + for pi in range(accelerator.state.num_processes): + if pi == accelerator.state.local_process_index: + logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = _load_target_model( + args.pretrained_model_name_or_path, + args.vae, + model_version, + weight_dtype, + accelerator.device if args.lowram else "cpu", + model_dtype, + args.disable_mmap_load_safetensors, + ) + + # work on low-ram device + if args.lowram: + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + unet.to(accelerator.device) + vae.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + accelerator.wait_for_everyone() + + return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info + + +def _load_target_model( + name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False +): + # model_dtype only work with full fp16/bf16 + name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path + load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers + + if load_stable_diffusion_format: + logger.info(f"load StableDiffusion checkpoint: {name_or_path}") + ( + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap) + else: + # Diffusers model is loaded to CPU + from diffusers import StableDiffusionXLPipeline + + variant = "fp16" if weight_dtype == torch.float16 else None + logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") + try: + try: + pipe = StableDiffusionXLPipeline.from_pretrained( + name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None + ) + except EnvironmentError as ex: + if variant is not None: + logger.info("try to load fp32 model") + pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None) + else: + raise ex + except EnvironmentError as ex: + logger.error( + f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" + ) + raise ex + + text_encoder1 = pipe.text_encoder + text_encoder2 = pipe.text_encoder_2 + + # convert to fp32 for cache text_encoders outputs + if text_encoder1.dtype != torch.float32: + text_encoder1 = text_encoder1.to(dtype=torch.float32) + if text_encoder2.dtype != torch.float32: + text_encoder2 = text_encoder2.to(dtype=torch.float32) + + vae = pipe.vae + unet = pipe.unet + del pipe + + # Diffusers U-Net to original U-Net + state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict()) + with init_empty_weights(): + unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet + sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype) + logger.info("U-Net converted to original U-Net") + + logit_scale = None + ckpt_info = None + + # VAEを読み込む + if vae_path is not None: + vae = model_util.load_vae(vae_path, weight_dtype) + logger.info("additional VAE loaded") + + return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info + + +def load_tokenizers(args: argparse.Namespace): + logger.info("prepare tokenizers") + + original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH] + tokeniers = [] + for i, original_path in enumerate(original_paths): + tokenizer: CLIPTokenizer = None + if args.tokenizer_cache_dir: + local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) + if os.path.exists(local_tokenizer_path): + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") + tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) + + if tokenizer is None: + tokenizer = CLIPTokenizer.from_pretrained(original_path) + + if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") + tokenizer.save_pretrained(local_tokenizer_path) + + if i == 1: + tokenizer.pad_token_id = 0 # fix pad token id to make same as open clip tokenizer + + tokeniers.append(tokenizer) + + if hasattr(args, "max_token_length") and args.max_token_length is not None: + logger.info(f"update token length: {args.max_token_length}") + + return tokeniers + + +def match_mixed_precision(args, weight_dtype): + if args.full_fp16: + assert ( + weight_dtype == torch.float16 + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + return weight_dtype + elif args.full_bf16: + assert ( + weight_dtype == torch.bfloat16 + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + return weight_dtype + else: + return None + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def get_timestep_embedding(x, outdim): + assert len(x.shape) == 2 + b, dims = x.shape[0], x.shape[1] + x = torch.flatten(x) + emb = timestep_embedding(x, outdim) + emb = torch.reshape(emb, (b, dims * outdim)) + return emb + + +def get_size_embeddings(orig_size, crop_size, target_size, device): + emb1 = get_timestep_embedding(orig_size, 256) + emb2 = get_timestep_embedding(crop_size, 256) + emb3 = get_timestep_embedding(target_size, 256) + vector = torch.cat([emb1, emb2, emb3], dim=1).to(device) + return vector + + +def save_sd_model_on_train_end( + args: argparse.Namespace, + src_path: str, + save_stable_diffusion_format: bool, + use_safetensors: bool, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + text_encoder1, + text_encoder2, + unet, + vae, + logit_scale, + ckpt_info, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True) + sdxl_model_util.save_stable_diffusion_checkpoint( + ckpt_file, + text_encoder1, + text_encoder2, + unet, + epoch_no, + global_step, + ckpt_info, + vae, + logit_scale, + sai_metadata, + save_dtype, + ) + + def diffusers_saver(out_dir): + sdxl_model_util.save_diffusers_checkpoint( + out_dir, + text_encoder1, + text_encoder2, + unet, + src_path, + vae, + use_safetensors=use_safetensors, + save_dtype=save_dtype, + ) + + train_util.save_sd_model_on_train_end_common( + args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver + ) + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_sd_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + src_path, + save_stable_diffusion_format: bool, + use_safetensors: bool, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + text_encoder1, + text_encoder2, + unet, + vae, + logit_scale, + ckpt_info, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True) + sdxl_model_util.save_stable_diffusion_checkpoint( + ckpt_file, + text_encoder1, + text_encoder2, + unet, + epoch_no, + global_step, + ckpt_info, + vae, + logit_scale, + sai_metadata, + save_dtype, + ) + + def diffusers_saver(out_dir): + sdxl_model_util.save_diffusers_checkpoint( + out_dir, + text_encoder1, + text_encoder2, + unet, + src_path, + vae, + use_safetensors=use_safetensors, + save_dtype=save_dtype, + ) + + train_util.save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + save_stable_diffusion_format, + use_safetensors, + epoch, + num_train_epochs, + global_step, + sd_saver, + diffusers_saver, + ) + + +def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True): + parser.add_argument( + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) + parser.add_argument( + "--disable_mmap_load_safetensors", + action="store_true", + help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", + ) + + +def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): + assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" + if args.v_parameterization: + logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") + + if args.clip_skip is not None: + logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") + + # if args.multires_noise_iterations: + # logger.info( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります" + # ) + # else: + # if args.noise_offset is None: + # args.noise_offset = DEFAULT_NOISE_OFFSET + # elif args.noise_offset != DEFAULT_NOISE_OFFSET: + # logger.info( + # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています" + # ) + # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") + + # assert ( + # not hasattr(args, "weighted_captions") or not args.weighted_captions + # ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" + + if supportTextEncoderCaching: + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + args.cache_text_encoder_outputs = True + logger.warning( + "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / " + + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました" + ) + + +def sample_images(*args, **kwargs): + return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) diff --git a/library/slicing_vae.py b/library/slicing_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..ea765342961ab0c55721f5b8a81cbca2ec01feb5 --- /dev/null +++ b/library/slicing_vae.py @@ -0,0 +1,682 @@ +# Modified from Diffusers to reduce VRAM usage + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block +from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution +from diffusers.models.autoencoder_kl import AutoencoderKLOutput +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) + +def slice_h(x, num_slices): + # slice with pad 1 both sides: to eliminate side effect of padding of conv2d + # Conv2dのpaddingの副作用を排除するために、両側にpad 1しながらHをスライスする + # NCHWでもNHWCでもどちらでも動く + size = (x.shape[2] + num_slices - 1) // num_slices + sliced = [] + for i in range(num_slices): + if i == 0: + sliced.append(x[:, :, : size + 1, :]) + else: + end = size * (i + 1) + 1 + if x.shape[2] - end < 3: # if the last slice is too small, use the rest of the tensor 最後が細すぎるとconv2dできないので全部使う + end = x.shape[2] + sliced.append(x[:, :, size * i - 1 : end, :]) + if end >= x.shape[2]: + break + return sliced + + +def cat_h(sliced): + # padding分を除いて結合する + cat = [] + for i, x in enumerate(sliced): + if i == 0: + cat.append(x[:, :, :-1, :]) + elif i == len(sliced) - 1: + cat.append(x[:, :, 1:, :]) + else: + cat.append(x[:, :, 1:-1, :]) + del x + x = torch.cat(cat, dim=2) + return x + + +def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs): + assert _self.upsample is None and _self.downsample is None + assert _self.norm1.num_groups == _self.norm2.num_groups + assert temb is None + + # make sure norms are on cpu + org_device = input_tensor.device + cpu_device = torch.device("cpu") + _self.norm1.to(cpu_device) + _self.norm2.to(cpu_device) + + # GroupNormがCPUでfp16で動かない対策 + org_dtype = input_tensor.dtype + if org_dtype == torch.float16: + _self.norm1.to(torch.float32) + _self.norm2.to(torch.float32) + + # すべてのテンソルをCPUに移動する + input_tensor = input_tensor.to(cpu_device) + hidden_states = input_tensor + + # どうもこれは結果が異なるようだ…… + # def sliced_norm1(norm, x): + # num_div = 4 if up_block_idx <= 2 else x.shape[1] // norm.num_groups + # sliced_tensor = torch.chunk(x, num_div, dim=1) + # sliced_weight = torch.chunk(norm.weight, num_div, dim=0) + # sliced_bias = torch.chunk(norm.bias, num_div, dim=0) + # logger.info(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape) + # normed_tensor = [] + # for i in range(num_div): + # n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps) + # normed_tensor.append(n) + # del n + # x = torch.cat(normed_tensor, dim=1) + # return num_div, x + + # normを分割すると結果が変わるので、ここだけは分割しない。GPUで計算するとVRAMが足りなくなるので、CPUで計算する。幸いCPUでもそこまで遅くない + if org_dtype == torch.float16: + hidden_states = hidden_states.to(torch.float32) + hidden_states = _self.norm1(hidden_states) # run on cpu + if org_dtype == torch.float16: + hidden_states = hidden_states.to(torch.float16) + + sliced = slice_h(hidden_states, num_slices) + del hidden_states + + for i in range(len(sliced)): + x = sliced[i] + sliced[i] = None + + # 計算する部分だけGPUに移動する、以下同様 + x = x.to(org_device) + x = _self.nonlinearity(x) + x = _self.conv1(x) + x = x.to(cpu_device) + sliced[i] = x + del x + + hidden_states = cat_h(sliced) + del sliced + + if org_dtype == torch.float16: + hidden_states = hidden_states.to(torch.float32) + hidden_states = _self.norm2(hidden_states) # run on cpu + if org_dtype == torch.float16: + hidden_states = hidden_states.to(torch.float16) + + sliced = slice_h(hidden_states, num_slices) + del hidden_states + + for i in range(len(sliced)): + x = sliced[i] + sliced[i] = None + + x = x.to(org_device) + x = _self.nonlinearity(x) + x = _self.dropout(x) + x = _self.conv2(x) + x = x.to(cpu_device) + sliced[i] = x + del x + + hidden_states = cat_h(sliced) + del sliced + + # make shortcut + if _self.conv_shortcut is not None: + sliced = list(torch.chunk(input_tensor, num_slices, dim=2)) # no padding in conv_shortcut パディングがないので普通にスライスする + del input_tensor + + for i in range(len(sliced)): + x = sliced[i] + sliced[i] = None + + x = x.to(org_device) + x = _self.conv_shortcut(x) + x = x.to(cpu_device) + sliced[i] = x + del x + + input_tensor = torch.cat(sliced, dim=2) + del sliced + + output_tensor = (input_tensor + hidden_states) / _self.output_scale_factor + + output_tensor = output_tensor.to(org_device) # 次のレイヤーがGPUで計算する + return output_tensor + + +class SlicingEncoder(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + double_z=True, + num_slices=2, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) + + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + ) + self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # とりあえずDiffusersのxformersを使う + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + # replace forward of ResBlocks + def wrapper(func, module, num_slices): + def forward(*args, **kwargs): + return func(module, num_slices, *args, **kwargs) + + return forward + + self.num_slices = num_slices + div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす + # logger.info(f"initial divisor: {div}") + if div >= 2: + div = int(div) + for resnet in self.mid_block.resnets: + resnet.forward = wrapper(resblock_forward, resnet, div) + # midblock doesn't have downsample + + for i, down_block in enumerate(self.down_blocks[::-1]): + if div >= 2: + div = int(div) + # logger.info(f"down block: {i} divisor: {div}") + for resnet in down_block.resnets: + resnet.forward = wrapper(resblock_forward, resnet, div) + if down_block.downsamplers is not None: + # logger.info("has downsample") + for downsample in down_block.downsamplers: + downsample.forward = wrapper(self.downsample_forward, downsample, div * 2) + div *= 2 + + def forward(self, x): + sample = x + del x + + org_device = sample.device + cpu_device = torch.device("cpu") + + # sample = self.conv_in(sample) + sample = sample.to(cpu_device) + sliced = slice_h(sample, self.num_slices) + del sample + + for i in range(len(sliced)): + x = sliced[i] + sliced[i] = None + + x = x.to(org_device) + x = self.conv_in(x) + x = x.to(cpu_device) + sliced[i] = x + del x + + sample = cat_h(sliced) + del sliced + + sample = sample.to(org_device) + + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + # ここも省メモリ化したいが、恐らくそこまでメモリを食わないので省略 + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + def downsample_forward(self, _self, num_slices, hidden_states): + assert hidden_states.shape[1] == _self.channels + assert _self.use_conv and _self.padding == 0 + logger.info(f"downsample forward {num_slices} {hidden_states.shape}") + + org_device = hidden_states.device + cpu_device = torch.device("cpu") + + hidden_states = hidden_states.to(cpu_device) + pad = (0, 1, 0, 1) + hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0) + + # slice with even number because of stride 2 + # strideが2なので偶数でスライスする + # slice with pad 1 both sides: to eliminate side effect of padding of conv2d + size = (hidden_states.shape[2] + num_slices - 1) // num_slices + size = size + 1 if size % 2 == 1 else size + + sliced = [] + for i in range(num_slices): + if i == 0: + sliced.append(hidden_states[:, :, : size + 1, :]) + else: + end = size * (i + 1) + 1 + if hidden_states.shape[2] - end < 4: # if the last slice is too small, use the rest of the tensor + end = hidden_states.shape[2] + sliced.append(hidden_states[:, :, size * i - 1 : end, :]) + if end >= hidden_states.shape[2]: + break + del hidden_states + + for i in range(len(sliced)): + x = sliced[i] + sliced[i] = None + + x = x.to(org_device) + x = _self.conv(x) + x = x.to(cpu_device) + + # ここだけ雰囲気が違うのはCopilotのせい + if i == 0: + hidden_states = x + else: + hidden_states = torch.cat([hidden_states, x], dim=2) + + hidden_states = hidden_states.to(org_device) + # logger.info(f"downsample forward done {hidden_states.shape}") + return hidden_states + + +class SlicingDecoder(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + num_slices=2, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + ) + self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # とりあえずDiffusersのxformersを使う + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=None, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + # replace forward of ResBlocks + def wrapper(func, module, num_slices): + def forward(*args, **kwargs): + return func(module, num_slices, *args, **kwargs) + + return forward + + self.num_slices = num_slices + div = num_slices / (2 ** (len(self.up_blocks) - 1)) + logger.info(f"initial divisor: {div}") + if div >= 2: + div = int(div) + for resnet in self.mid_block.resnets: + resnet.forward = wrapper(resblock_forward, resnet, div) + # midblock doesn't have upsample + + for i, up_block in enumerate(self.up_blocks): + if div >= 2: + div = int(div) + # logger.info(f"up block: {i} divisor: {div}") + for resnet in up_block.resnets: + resnet.forward = wrapper(resblock_forward, resnet, div) + if up_block.upsamplers is not None: + # logger.info("has upsample") + for upsample in up_block.upsamplers: + upsample.forward = wrapper(self.upsample_forward, upsample, div * 2) + div *= 2 + + def forward(self, z): + sample = z + del z + sample = self.conv_in(sample) + + # middle + sample = self.mid_block(sample) + + # up + for i, up_block in enumerate(self.up_blocks): + sample = up_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + # conv_out with slicing because of VRAM usage + # conv_outはとてもVRAM使うのでスライスして対応 + org_device = sample.device + cpu_device = torch.device("cpu") + sample = sample.to(cpu_device) + + sliced = slice_h(sample, self.num_slices) + del sample + for i in range(len(sliced)): + x = sliced[i] + sliced[i] = None + + x = x.to(org_device) + x = self.conv_out(x) + x = x.to(cpu_device) + sliced[i] = x + sample = cat_h(sliced) + del sliced + + sample = sample.to(org_device) + return sample + + def upsample_forward(self, _self, num_slices, hidden_states, output_size=None): + assert hidden_states.shape[1] == _self.channels + assert _self.use_conv_transpose == False and _self.use_conv + + org_dtype = hidden_states.dtype + org_device = hidden_states.device + cpu_device = torch.device("cpu") + + hidden_states = hidden_states.to(cpu_device) + sliced = slice_h(hidden_states, num_slices) + del hidden_states + + for i in range(len(sliced)): + x = sliced[i] + sliced[i] = None + + x = x.to(org_device) + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + # PyTorch 2で直らないかね…… + if org_dtype == torch.bfloat16: + x = x.to(torch.float32) + + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + + if org_dtype == torch.bfloat16: + x = x.to(org_dtype) + + x = _self.conv(x) + + # upsampleされてるのでpadは2になる + if i == 0: + x = x[:, :, :-2, :] + elif i == num_slices - 1: + x = x[:, :, 2:, :] + else: + x = x[:, :, 2:-2, :] + + x = x.to(cpu_device) + sliced[i] = x + del x + + hidden_states = torch.cat(sliced, dim=2) + # logger.info(f"us hidden_states {hidden_states.shape}") + del sliced + + hidden_states = hidden_states.to(org_device) + return hidden_states + + +class SlicingAutoencoderKL(ModelMixin, ConfigMixin): + r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma + and Max Welling. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(64,)`): Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): TODO + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + num_slices: int = 16, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = SlicingEncoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + num_slices=num_slices, + ) + + # pass init params to Decoder + self.decoder = SlicingDecoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + num_slices=num_slices, + ) + + self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) + self.use_slicing = False + + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + z = self.post_quant_conv(z) + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + # これはバッチ方向のスライシング 紛らわしい + def enable_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/library/strategy_base.py b/library/strategy_base.py new file mode 100644 index 0000000000000000000000000000000000000000..27272727bc789057e0a5cf8a9dc2308179e3942f --- /dev/null +++ b/library/strategy_base.py @@ -0,0 +1,570 @@ +# base class for platform strategies. this file defines the interface for strategies + +import os +import re +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection + + +# TODO remove circular import by moving ImageInfo to a separate file +# from library.train_util import ImageInfo + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class TokenizeStrategy: + _strategy = None # strategy instance: actual strategy class + + _re_attention = re.compile( + r"""\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, + ) + + @classmethod + def set_strategy(cls, strategy): + #if cls._strategy is not None: + # raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TokenizeStrategy"]: + return cls._strategy + + def _load_tokenizer( + self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None + ) -> Any: + tokenizer = None + if tokenizer_cache_dir: + local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_")) + if os.path.exists(local_tokenizer_path): + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") + tokenizer = model_class.from_pretrained(local_tokenizer_path) # same for v1 and v2 + + if tokenizer is None: + tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder) + + if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") + tokenizer.save_pretrained(local_tokenizer_path) + + return tokenizer + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + raise NotImplementedError + + def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + returns: [tokens1, tokens2, ...], [weights1, weights2, ...] + """ + raise NotImplementedError + + def _get_weighted_input_ids( + self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + max_length includes starting and ending tokens. + """ + + def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in TokenizeStrategy._re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + def get_prompts_with_weights(text: str, max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. max_length does not include starting and ending token. + + No padding, starting or ending token is included. + """ + truncated = False + + texts_and_weights = parse_prompt_attention(text) + tokens = [] + weights = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + tokens += token + # copy the weight by length of token + weights += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(tokens) > max_length: + truncated = True + break + # truncate + if len(tokens) > max_length: + truncated = True + tokens = tokens[:max_length] + weights = weights[:max_length] + if truncated: + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + tokens = [bos] + tokens + [eos] + [pad] * (max_length - 2 - len(tokens)) + weights = [1.0] + weights + [1.0] * (max_length - 1 - len(weights)) + return tokens, weights + + if max_length is None: + max_length = tokenizer.model_max_length + + tokens, weights = get_prompts_with_weights(text, max_length - 2) + tokens, weights = pad_tokens_and_weights( + tokens, weights, max_length, tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id + ) + return torch.tensor(tokens).unsqueeze(0), torch.tensor(weights).unsqueeze(0) + + def _get_input_ids( + self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None, weighted: bool = False + ) -> torch.Tensor: + """ + for SD1.5/2.0/SDXL + TODO support batch input + """ + if max_length is None: + max_length = tokenizer.model_max_length - 2 + + if weighted: + input_ids, weights = self._get_weighted_input_ids(tokenizer, text, max_length) + else: + input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids + + if max_length > tokenizer.model_max_length: + input_ids = input_ids.squeeze(0) + iids_list = [] + if tokenizer.pad_token_id == tokenizer.eos_token_id: + # v1 + # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する + # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): # (1, 152, 75) + ids_chunk = ( + input_ids[0].unsqueeze(0), + input_ids[i : i + tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 or SDXL + # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): + ids_chunk = ( + input_ids[0].unsqueeze(0), # BOS + input_ids[i : i + tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) + + # 末尾が または の場合は、何もしなくてよい + # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) + if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id: + ids_chunk[-1] = tokenizer.eos_token_id + # 先頭が ... の場合は ... に変える + if ids_chunk[1] == tokenizer.pad_token_id: + ids_chunk[1] = tokenizer.eos_token_id + + iids_list.append(ids_chunk) + + input_ids = torch.stack(iids_list) # 3,77 + + if weighted: + weights = weights.squeeze(0) + new_weights = torch.ones(input_ids.shape) + for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): + b = i // (tokenizer.model_max_length - 2) + new_weights[b, 1 : 1 + tokenizer.model_max_length - 2] = weights[i : i + tokenizer.model_max_length - 2] + weights = new_weights + + if weighted: + return input_ids, weights + return input_ids + + +class TextEncodingStrategy: + _strategy = None # strategy instance: actual strategy class + + @classmethod + def set_strategy(cls, strategy): + #if cls._strategy is not None: + # raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TextEncodingStrategy"]: + return cls._strategy + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Encode tokens into embeddings and outputs. + :param tokens: list of token tensors for each TextModel + :return: list of output embeddings for each architecture + """ + raise NotImplementedError + + def encode_tokens_with_weights( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Encode tokens into embeddings and outputs. + :param tokens: list of token tensors for each TextModel + :param weights: list of weight tensors for each TextModel + :return: list of output embeddings for each architecture + """ + raise NotImplementedError + + +class TextEncoderOutputsCachingStrategy: + _strategy = None # strategy instance: actual strategy class + + def __init__( + self, + cache_to_disk: bool, + batch_size: Optional[int], + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + is_weighted: bool = False, + ) -> None: + self._cache_to_disk = cache_to_disk + self._batch_size = batch_size + self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + self._is_partial = is_partial + self._is_weighted = is_weighted + + @classmethod + def set_strategy(cls, strategy): + #if cls._strategy is not None: + # raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]: + return cls._strategy + + @property + def cache_to_disk(self): + return self._cache_to_disk + + @property + def batch_size(self): + return self._batch_size + + @property + def is_partial(self): + return self._is_partial + + @property + def is_weighted(self): + return self._is_weighted + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + raise NotImplementedError + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + raise NotImplementedError + + def is_disk_cached_outputs_expected(self, npz_path: str) -> bool: + raise NotImplementedError + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List + ): + raise NotImplementedError + + +class LatentsCachingStrategy: + # TODO commonize utillity functions to this class, such as npz handling etc. + + _strategy = None # strategy instance: actual strategy class + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + self._cache_to_disk = cache_to_disk + self._batch_size = batch_size + self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + + @classmethod + def set_strategy(cls, strategy): + #if cls._strategy is not None: + # raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["LatentsCachingStrategy"]: + return cls._strategy + + @property + def cache_to_disk(self): + return self._cache_to_disk + + @property + def batch_size(self): + return self._batch_size + + @property + def cache_suffix(self): + raise NotImplementedError + + def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]: + w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + raise NotImplementedError + + def is_disk_cached_latents_expected( + self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + ) -> bool: + raise NotImplementedError + + def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + raise NotImplementedError + + def _default_is_disk_cached_latents_expected( + self, + latents_stride: int, + bucket_reso: Tuple[int, int], + npz_path: str, + flip_aug: bool, + alpha_mask: bool, + multi_resolution: bool = False, + ): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) + + # e.g. "_32x64", HxW + key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else "" + + try: + npz = np.load(npz_path) + if "latents" + key_reso_suffix not in npz: + return False + if flip_aug and "latents_flipped" + key_reso_suffix not in npz: + return False + if alpha_mask and "alpha_mask" + key_reso_suffix not in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + # TODO remove circular dependency for ImageInfo + def _default_cache_batch_latents( + self, + encode_by_vae, + vae_device, + vae_dtype, + image_infos: List, + flip_aug: bool, + alpha_mask: bool, + random_crop: bool, + multi_resolution: bool = False, + ): + """ + Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common. + """ + from . import train_util # import here to avoid circular import + + img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching( + image_infos, alpha_mask, random_crop + ) + img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype) + + with torch.no_grad(): + latents_tensors = encode_by_vae(img_tensor).to("cpu") + if flip_aug: + img_tensor = torch.flip(img_tensor, dims=[3]) + with torch.no_grad(): + flipped_latents = encode_by_vae(img_tensor).to("cpu") + else: + flipped_latents = [None] * len(latents_tensors) + + # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks): + for i in range(len(image_infos)): + info = image_infos[i] + latents = latents_tensors[i] + flipped_latent = flipped_latents[i] + alpha_mask = alpha_masks[i] + original_size = original_sizes[i] + crop_ltrb = crop_ltrbs[i] + + latents_size = latents.shape[1:3] # H, W + key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW + + if self.cache_to_disk: + self.save_latents_to_disk( + info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix + ) + else: + info.latents_original_size = original_size + info.latents_crop_ltrb = crop_ltrb + info.latents = latents + if flip_aug: + info.latents_flipped = flipped_latent + info.alpha_mask = alpha_mask + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + """ + for SD/SDXL + """ + return self._default_load_latents_from_disk(None, npz_path, bucket_reso) + + def _default_load_latents_from_disk( + self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + if latents_stride is None: + key_reso_suffix = "" + else: + latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) + key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW + + npz = np.load(npz_path) + if "latents" + key_reso_suffix not in npz: + raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}") + + latents = npz["latents" + key_reso_suffix] + original_size = npz["original_size" + key_reso_suffix].tolist() + crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist() + flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None + alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask + + def save_latents_to_disk( + self, + npz_path, + latents_tensor, + original_size, + crop_ltrb, + flipped_latents_tensor=None, + alpha_mask=None, + key_reso_suffix="", + ): + kwargs = {} + + if os.path.exists(npz_path): + # load existing npz and update it + npz = np.load(npz_path) + for key in npz.files: + kwargs[key] = npz[key] + + kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy() + kwargs["original_size" + key_reso_suffix] = np.array(original_size) + kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb) + if flipped_latents_tensor is not None: + kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy() + if alpha_mask is not None: + kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy() + np.savez(npz_path, **kwargs) \ No newline at end of file diff --git a/library/strategy_flux.py b/library/strategy_flux.py new file mode 100644 index 0000000000000000000000000000000000000000..7e0340af15928e0e35472bfe4611636cad0c0187 --- /dev/null +++ b/library/strategy_flux.py @@ -0,0 +1,270 @@ +import os +import glob +from typing import Any, List, Optional, Tuple, Union +import torch +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast + +from . import train_util +from .flux_utils import get_t5xxl_actual_dtype +from .strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14" +T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl" + + +class FluxTokenizeStrategy(TokenizeStrategy): + def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None: + self.t5xxl_max_length = t5xxl_max_length + self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + + l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt") + + t5_attn_mask = t5_tokens["attention_mask"] + l_tokens = l_tokens["input_ids"] + t5_tokens = t5_tokens["input_ids"] + + return [l_tokens, t5_tokens, t5_attn_mask] + + +class FluxTextEncodingStrategy(TextEncodingStrategy): + def __init__(self, apply_t5_attn_mask: Optional[bool] = None) -> None: + """ + Args: + apply_t5_attn_mask: Default value for apply_t5_attn_mask. + """ + self.apply_t5_attn_mask = apply_t5_attn_mask + + def encode_tokens( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: List[torch.Tensor], + apply_t5_attn_mask: Optional[bool] = None, + ) -> List[torch.Tensor]: + # supports single model inference + + if apply_t5_attn_mask is None: + apply_t5_attn_mask = self.apply_t5_attn_mask + + clip_l, t5xxl = models if len(models) == 2 else (models[0], None) + l_tokens, t5_tokens = tokens[:2] + t5_attn_mask = tokens[2] if len(tokens) > 2 else None + + # clip_l is None when using T5 only + if clip_l is not None and l_tokens is not None: + l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"] + else: + l_pooled = None + + # t5xxl is None when using CLIP only + if t5xxl is not None and t5_tokens is not None: + # t5_out is [b, max length, 4096] + attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device) + t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True) + # if zero_pad_t5_output: + # t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1) + txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device) + else: + t5_out = None + txt_ids = None + t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one + + return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer + + +class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz" + + def __init__( + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + apply_t5_attn_mask: bool = False, + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + self.apply_t5_attn_mask = apply_t5_attn_mask + + self.warn_fp8_weights = False + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, npz_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(npz_path) + if "l_pooled" not in npz: + return False + if "t5_out" not in npz: + return False + if "txt_ids" not in npz: + return False + if "t5_attn_mask" not in npz: + return False + if "apply_t5_attn_mask" not in npz: + return False + npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] + if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + l_pooled = data["l_pooled"] + t5_out = data["t5_out"] + txt_ids = data["txt_ids"] + t5_attn_mask = data["t5_attn_mask"] + # apply_t5_attn_mask should be same as self.apply_t5_attn_mask + return [l_pooled, t5_out, txt_ids, t5_attn_mask] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + if not self.warn_fp8_weights: + if get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn: + logger.warning( + "T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs." + ) + self.warn_fp8_weights = True + + flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy + captions = [info.caption for info in infos] + + tokens_and_masks = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + # attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True + l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks) + + if l_pooled.dtype == torch.bfloat16: + l_pooled = l_pooled.float() + if t5_out.dtype == torch.bfloat16: + t5_out = t5_out.float() + if txt_ids.dtype == torch.bfloat16: + txt_ids = txt_ids.float() + + l_pooled = l_pooled.cpu().numpy() + t5_out = t5_out.cpu().numpy() + txt_ids = txt_ids.cpu().numpy() + t5_attn_mask = tokens_and_masks[2].cpu().numpy() + + for i, info in enumerate(infos): + l_pooled_i = l_pooled[i] + t5_out_i = t5_out[i] + txt_ids_i = txt_ids[i] + t5_attn_mask_i = t5_attn_mask[i] + apply_t5_attn_mask_i = self.apply_t5_attn_mask + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + l_pooled=l_pooled_i, + t5_out=t5_out_i, + txt_ids=txt_ids_i, + t5_attn_mask=t5_attn_mask_i, + apply_t5_attn_mask=apply_t5_attn_mask_i, + ) + else: + info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i) + + +class FluxLatentsCachingStrategy(LatentsCachingStrategy): + FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz" + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + @property + def cache_suffix(self) -> str: + return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents( + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True + ) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) + + +if __name__ == "__main__": + # test code for FluxTokenizeStrategy + # tokenizer = sd3_models.SD3Tokenizer() + strategy = FluxTokenizeStrategy(256) + text = "hello world" + + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + # print(l_tokens.shape) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + texts = ["hello world", "the quick brown fox jumps over the lazy dog"] + l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens_2 = strategy.t5xxl( + texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + print(l_tokens_2) + print(g_tokens_2) + print(t5_tokens_2) + + # compare + print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0])) + print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0])) + print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0])) + + text = ",".join(["hello world! this is long text"] * 50) + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + print(f"model max length l: {strategy.clip_l.model_max_length}") + print(f"model max length g: {strategy.clip_g.model_max_length}") + print(f"model max length t5: {strategy.t5xxl.model_max_length}") \ No newline at end of file diff --git a/library/strategy_sd.py b/library/strategy_sd.py new file mode 100644 index 0000000000000000000000000000000000000000..c883644070ed1466406ad553dda02d88140693b8 --- /dev/null +++ b/library/strategy_sd.py @@ -0,0 +1,171 @@ +import glob +import os +from typing import Any, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTokenizer +from . import train_util +from .strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +TOKENIZER_ID = "openai/clip-vit-large-patch14" +V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ + + +class SdTokenizeStrategy(TokenizeStrategy): + def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None: + """ + max_length does not include and (None, 75, 150, 225) + """ + logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer") + if v2: + self.tokenizer = self._load_tokenizer( + CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir + ) + else: + self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + + if max_length is None: + self.max_length = self.tokenizer.model_max_length + else: + self.max_length = max_length + 2 + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] + + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: + text = [text] if isinstance(text, str) else text + tokens_list = [] + weights_list = [] + for t in text: + tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True) + tokens_list.append(tokens) + weights_list.append(weights) + return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)] + + +class SdTextEncodingStrategy(TextEncodingStrategy): + def __init__(self, clip_skip: Optional[int] = None) -> None: + self.clip_skip = clip_skip + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + text_encoder = models[0] + tokens = tokens[0] + sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy + + # tokens: b,n,77 + b_size = tokens.size()[0] + max_token_length = tokens.size()[1] * tokens.size()[2] + model_max_length = sd_tokenize_strategy.tokenizer.model_max_length + tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77 + + tokens = tokens.to(text_encoder.device) + + if self.clip_skip is None: + encoder_hidden_states = text_encoder(tokens)[0] + else: + enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip] + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + + # bs*3, 77, 768 or 1024 + encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) + + if max_token_length != model_max_length: + v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id + if not v1: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, model_max_length): + chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token: + # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: ... の三連を ... へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, model_max_length): + states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) + + return [encoder_hidden_states] + + def encode_tokens_with_weights( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens_list: List[torch.Tensor], + weights_list: List[torch.Tensor], + ) -> List[torch.Tensor]: + encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0] + + weights = weights_list[0].to(encoder_hidden_states.device) + + # apply weights + if weights.shape[1] == 1: # no max_token_length + # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) + encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2) + else: + # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) + for i in range(weights.shape[1]): + encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[ + :, i, 1:-1 + ].unsqueeze(-1) + + return [encoder_hidden_states] + + +class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): + # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix. + # and we keep the old npz for the backward compatibility. + + SD_OLD_LATENTS_NPZ_SUFFIX = ".npz" + SD_LATENTS_NPZ_SUFFIX = "_sd.npz" + SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz" + + def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + self.sd = sd + self.suffix = ( + SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX + ) + + @property + def cache_suffix(self) -> str: + return self.suffix + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + # support old .npz + old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX + if os.path.exists(old_npz_file): + return old_npz_file + return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample() + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py new file mode 100644 index 0000000000000000000000000000000000000000..5da4fa7e51a15c84ff9bcbf299a3832f6864cbe0 --- /dev/null +++ b/library/strategy_sd3.py @@ -0,0 +1,456 @@ +import os +import glob +from typing import Any, List, Optional, Tuple, Union +import torch +import numpy as np +from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel + +from . import train_util +from .strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14" +CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" +T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl" + + +class Sd3TokenizeStrategy(TokenizeStrategy): + def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None: + self.t5xxl_max_length = t5xxl_max_length + self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.clip_g = self._load_tokenizer(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) + self.clip_g.pad_token_id = 0 # use 0 as pad token for clip_g + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + + l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt") + + l_attn_mask = l_tokens["attention_mask"] + g_attn_mask = g_tokens["attention_mask"] + t5_attn_mask = t5_tokens["attention_mask"] + l_tokens = l_tokens["input_ids"] + g_tokens = g_tokens["input_ids"] + t5_tokens = t5_tokens["input_ids"] + + return [l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask] + + +class Sd3TextEncodingStrategy(TextEncodingStrategy): + def __init__( + self, + apply_lg_attn_mask: Optional[bool] = None, + apply_t5_attn_mask: Optional[bool] = None, + l_dropout_rate: float = 0.0, + g_dropout_rate: float = 0.0, + t5_dropout_rate: float = 0.0, + ) -> None: + """ + Args: + apply_t5_attn_mask: Default value for apply_t5_attn_mask. + """ + self.apply_lg_attn_mask = apply_lg_attn_mask + self.apply_t5_attn_mask = apply_t5_attn_mask + self.l_dropout_rate = l_dropout_rate + self.g_dropout_rate = g_dropout_rate + self.t5_dropout_rate = t5_dropout_rate + + def encode_tokens( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens: List[torch.Tensor], + apply_lg_attn_mask: Optional[bool] = False, + apply_t5_attn_mask: Optional[bool] = False, + enable_dropout: bool = True, + ) -> List[torch.Tensor]: + """ + returned embeddings are not masked + """ + clip_l, clip_g, t5xxl = models + clip_l: Optional[CLIPTextModel] + clip_g: Optional[CLIPTextModelWithProjection] + t5xxl: Optional[T5EncoderModel] + + if apply_lg_attn_mask is None: + apply_lg_attn_mask = self.apply_lg_attn_mask + if apply_t5_attn_mask is None: + apply_t5_attn_mask = self.apply_t5_attn_mask + + l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens + + # dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings + + if l_tokens is None or clip_l is None: + assert g_tokens is None, "g_tokens must be None if l_tokens is None" + lg_out = None + lg_pooled = None + l_attn_mask = None + g_attn_mask = None + else: + assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" + + # drop some members of the batch: we do not call clip_l and clip_g for dropped members + batch_size, l_seq_len = l_tokens.shape + g_seq_len = g_tokens.shape[1] + + non_drop_l_indices = [] + non_drop_g_indices = [] + for i in range(l_tokens.shape[0]): + drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate) + drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate) + if not drop_l: + non_drop_l_indices.append(i) + if not drop_g: + non_drop_g_indices.append(i) + + # filter out dropped members + if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size: + l_tokens = l_tokens[non_drop_l_indices] + l_attn_mask = l_attn_mask[non_drop_l_indices] + if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size: + g_tokens = g_tokens[non_drop_g_indices] + g_attn_mask = g_attn_mask[non_drop_g_indices] + + # call clip_l for non-dropped members + if len(non_drop_l_indices) > 0: + nd_l_attn_mask = l_attn_mask.to(clip_l.device) + prompt_embeds = clip_l( + l_tokens.to(clip_l.device), nd_l_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True + ) + nd_l_pooled = prompt_embeds[0] + nd_l_out = prompt_embeds.hidden_states[-2] + if len(non_drop_g_indices) > 0: + nd_g_attn_mask = g_attn_mask.to(clip_g.device) + prompt_embeds = clip_g( + g_tokens.to(clip_g.device), nd_g_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True + ) + nd_g_pooled = prompt_embeds[0] + nd_g_out = prompt_embeds.hidden_states[-2] + + # fill in the dropped members + if len(non_drop_l_indices) == batch_size: + l_pooled = nd_l_pooled + l_out = nd_l_out + else: + # model output is always float32 because of the models are wrapped with Accelerator + l_pooled = torch.zeros((batch_size, 768), device=clip_l.device, dtype=torch.float32) + l_out = torch.zeros((batch_size, l_seq_len, 768), device=clip_l.device, dtype=torch.float32) + l_attn_mask = torch.zeros((batch_size, l_seq_len), device=clip_l.device, dtype=l_attn_mask.dtype) + if len(non_drop_l_indices) > 0: + l_pooled[non_drop_l_indices] = nd_l_pooled + l_out[non_drop_l_indices] = nd_l_out + l_attn_mask[non_drop_l_indices] = nd_l_attn_mask + + if len(non_drop_g_indices) == batch_size: + g_pooled = nd_g_pooled + g_out = nd_g_out + else: + g_pooled = torch.zeros((batch_size, 1280), device=clip_g.device, dtype=torch.float32) + g_out = torch.zeros((batch_size, g_seq_len, 1280), device=clip_g.device, dtype=torch.float32) + g_attn_mask = torch.zeros((batch_size, g_seq_len), device=clip_g.device, dtype=g_attn_mask.dtype) + if len(non_drop_g_indices) > 0: + g_pooled[non_drop_g_indices] = nd_g_pooled + g_out[non_drop_g_indices] = nd_g_out + g_attn_mask[non_drop_g_indices] = nd_g_attn_mask + + lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) + lg_out = torch.cat([l_out, g_out], dim=-1) + + if t5xxl is None or t5_tokens is None: + t5_out = None + t5_attn_mask = None + else: + # drop some members of the batch: we do not call t5xxl for dropped members + batch_size, t5_seq_len = t5_tokens.shape + non_drop_t5_indices = [] + for i in range(t5_tokens.shape[0]): + drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate) + if not drop_t5: + non_drop_t5_indices.append(i) + + # filter out dropped members + if len(non_drop_t5_indices) > 0 and len(non_drop_t5_indices) < batch_size: + t5_tokens = t5_tokens[non_drop_t5_indices] + t5_attn_mask = t5_attn_mask[non_drop_t5_indices] + + # call t5xxl for non-dropped members + if len(non_drop_t5_indices) > 0: + nd_t5_attn_mask = t5_attn_mask.to(t5xxl.device) + nd_t5_out, _ = t5xxl( + t5_tokens.to(t5xxl.device), + nd_t5_attn_mask if apply_t5_attn_mask else None, + return_dict=False, + output_hidden_states=True, + ) + + # fill in the dropped members + if len(non_drop_t5_indices) == batch_size: + t5_out = nd_t5_out + else: + t5_out = torch.zeros((batch_size, t5_seq_len, 4096), device=t5xxl.device, dtype=torch.float32) + t5_attn_mask = torch.zeros((batch_size, t5_seq_len), device=t5xxl.device, dtype=t5_attn_mask.dtype) + if len(non_drop_t5_indices) > 0: + t5_out[non_drop_t5_indices] = nd_t5_out + t5_attn_mask[non_drop_t5_indices] = nd_t5_attn_mask + + # masks are used for attention masking in transformer + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] + + def drop_cached_text_encoder_outputs( + self, + lg_out: torch.Tensor, + t5_out: torch.Tensor, + lg_pooled: torch.Tensor, + l_attn_mask: torch.Tensor, + g_attn_mask: torch.Tensor, + t5_attn_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # dropout: if enable_dropout is True, dropout is not applied. dropout means zeroing out embeddings + if lg_out is not None: + for i in range(lg_out.shape[0]): + drop_l = self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate + if drop_l: + lg_out[i, :, :768] = torch.zeros_like(lg_out[i, :, :768]) + lg_pooled[i, :768] = torch.zeros_like(lg_pooled[i, :768]) + if l_attn_mask is not None: + l_attn_mask[i] = torch.zeros_like(l_attn_mask[i]) + drop_g = self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate + if drop_g: + lg_out[i, :, 768:] = torch.zeros_like(lg_out[i, :, 768:]) + lg_pooled[i, 768:] = torch.zeros_like(lg_pooled[i, 768:]) + if g_attn_mask is not None: + g_attn_mask[i] = torch.zeros_like(g_attn_mask[i]) + + if t5_out is not None: + for i in range(t5_out.shape[0]): + drop_t5 = self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate + if drop_t5: + t5_out[i] = torch.zeros_like(t5_out[i]) + if t5_attn_mask is not None: + t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i]) + + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] + + def concat_encodings( + self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + if t5_out is None: + t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype) + return torch.cat([lg_out, t5_out], dim=-2), lg_pooled + + +class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz" + + def __init__( + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + apply_lg_attn_mask: bool = False, + apply_t5_attn_mask: bool = False, + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + self.apply_lg_attn_mask = apply_lg_attn_mask + self.apply_t5_attn_mask = apply_t5_attn_mask + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, npz_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(npz_path) + if "lg_out" not in npz: + return False + if "lg_pooled" not in npz: + return False + if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used + return False + if "apply_lg_attn_mask" not in npz: + return False + if "t5_out" not in npz: + return False + if "t5_attn_mask" not in npz: + return False + npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"] + if npz_apply_lg_attn_mask != self.apply_lg_attn_mask: + return False + if "apply_t5_attn_mask" not in npz: + return False + npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"] + if npz_apply_t5_attn_mask != self.apply_t5_attn_mask: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + lg_out = data["lg_out"] + lg_pooled = data["lg_pooled"] + t5_out = data["t5_out"] + + l_attn_mask = data["clip_l_attn_mask"] + g_attn_mask = data["clip_g_attn_mask"] + t5_attn_mask = data["t5_attn_mask"] + + # apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask + return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy + captions = [info.caption for info in infos] + + tokens_and_masks = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + # always disable dropout during caching + lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens( + tokenize_strategy, + models, + tokens_and_masks, + apply_lg_attn_mask=self.apply_lg_attn_mask, + apply_t5_attn_mask=self.apply_t5_attn_mask, + enable_dropout=False, + ) + + if lg_out.dtype == torch.bfloat16: + lg_out = lg_out.float() + if lg_pooled.dtype == torch.bfloat16: + lg_pooled = lg_pooled.float() + if t5_out.dtype == torch.bfloat16: + t5_out = t5_out.float() + + lg_out = lg_out.cpu().numpy() + lg_pooled = lg_pooled.cpu().numpy() + t5_out = t5_out.cpu().numpy() + + l_attn_mask = tokens_and_masks[3].cpu().numpy() + g_attn_mask = tokens_and_masks[4].cpu().numpy() + t5_attn_mask = tokens_and_masks[5].cpu().numpy() + + for i, info in enumerate(infos): + lg_out_i = lg_out[i] + t5_out_i = t5_out[i] + lg_pooled_i = lg_pooled[i] + l_attn_mask_i = l_attn_mask[i] + g_attn_mask_i = g_attn_mask[i] + t5_attn_mask_i = t5_attn_mask[i] + apply_lg_attn_mask = self.apply_lg_attn_mask + apply_t5_attn_mask = self.apply_t5_attn_mask + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + lg_out=lg_out_i, + lg_pooled=lg_pooled_i, + t5_out=t5_out_i, + clip_l_attn_mask=l_attn_mask_i, + clip_g_attn_mask=g_attn_mask_i, + t5_attn_mask=t5_attn_mask_i, + apply_lg_attn_mask=apply_lg_attn_mask, + apply_t5_attn_mask=apply_t5_attn_mask, + ) + else: + # it's fine that attn mask is not None. it's overwritten before calling the model if necessary + info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i) + + +class Sd3LatentsCachingStrategy(LatentsCachingStrategy): + SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" + + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + + @property + def cache_suffix(self) -> str: + return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX + ) + + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution + + # TODO remove circular dependency for ImageInfo + def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): + encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu") + vae_device = vae.device + vae_dtype = vae.dtype + + self._default_cache_batch_latents( + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True + ) + + if not train_util.HIGH_VRAM: + train_util.clean_memory_on_device(vae.device) + + +if __name__ == "__main__": + # test code for Sd3TokenizeStrategy + # tokenizer = sd3_models.SD3Tokenizer() + strategy = Sd3TokenizeStrategy(256) + text = "hello world" + + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + # print(l_tokens.shape) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + texts = ["hello world", "the quick brown fox jumps over the lazy dog"] + l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt") + t5_tokens_2 = strategy.t5xxl( + texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + print(l_tokens_2) + print(g_tokens_2) + print(t5_tokens_2) + + # compare + print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0])) + print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0])) + print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0])) + + text = ",".join(["hello world! this is long text"] * 50) + l_tokens, g_tokens, t5_tokens = strategy.tokenize(text) + print(l_tokens) + print(g_tokens) + print(t5_tokens) + + print(f"model max length l: {strategy.clip_l.model_max_length}") + print(f"model max length g: {strategy.clip_g.model_max_length}") + print(f"model max length t5: {strategy.t5xxl.model_max_length}") diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py new file mode 100644 index 0000000000000000000000000000000000000000..5d5bc957408fed14b84255b47e54ccdd9a5979a2 --- /dev/null +++ b/library/strategy_sdxl.py @@ -0,0 +1,306 @@ +import os +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection +from .strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy + + +from .utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +TOKENIZER1_PATH = "openai/clip-vit-large-patch14" +TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + + +class SdxlTokenizeStrategy(TokenizeStrategy): + def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None: + self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir) + self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir) + self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2 + + if max_length is None: + self.max_length = self.tokenizer1.model_max_length + else: + self.max_length = max_length + 2 + + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: + text = [text] if isinstance(text, str) else text + return ( + torch.stack([self._get_input_ids(self.tokenizer1, t, self.max_length) for t in text], dim=0), + torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0), + ) + + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: + text = [text] if isinstance(text, str) else text + tokens1_list, tokens2_list = [], [] + weights1_list, weights2_list = [], [] + for t in text: + tokens1, weights1 = self._get_input_ids(self.tokenizer1, t, self.max_length, weighted=True) + tokens2, weights2 = self._get_input_ids(self.tokenizer2, t, self.max_length, weighted=True) + tokens1_list.append(tokens1) + tokens2_list.append(tokens2) + weights1_list.append(weights1) + weights2_list.append(weights2) + return [torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)], [ + torch.stack(weights1_list, dim=0), + torch.stack(weights2_list, dim=0), + ] + + +class SdxlTextEncodingStrategy(TextEncodingStrategy): + def __init__(self) -> None: + pass + + def _pool_workaround( + self, text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int + ): + r""" + workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output + instead of the hidden states for the EOS token + If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output + + Original code from CLIP's pooling function: + + \# text_embeds.shape = [batch_size, sequence_length, transformer.width] + \# take features from the eot embedding (eot_token is the highest number in each sequence) + \# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + ] + """ + + # input_ids: b*n,77 + # find index for EOS token + + # Following code is not working if one of the input_ids has multiple EOS tokens (very odd case) + # eos_token_index = torch.where(input_ids == eos_token_id)[1] + # eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # Create a mask where the EOS tokens are + eos_token_mask = (input_ids == eos_token_id).int() + + # Use argmax to find the last index of the EOS token for each element in the batch + eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine + eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # get hidden states for EOS token + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index + ] + + # apply projection: projection may be of different dtype than last_hidden_state + pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype)) + pooled_output = pooled_output.to(last_hidden_state.dtype) + + return pooled_output + + def _get_hidden_states_sdxl( + self, + input_ids1: torch.Tensor, + input_ids2: torch.Tensor, + tokenizer1: CLIPTokenizer, + tokenizer2: CLIPTokenizer, + text_encoder1: Union[CLIPTextModel, torch.nn.Module], + text_encoder2: Union[CLIPTextModelWithProjection, torch.nn.Module], + unwrapped_text_encoder2: Optional[CLIPTextModelWithProjection] = None, + ): + # input_ids: b,n,77 -> b*n, 77 + b_size = input_ids1.size()[0] + if input_ids1.size()[1] == 1: + max_token_length = None + else: + max_token_length = input_ids1.size()[1] * input_ids1.size()[2] + input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77 + input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77 + input_ids1 = input_ids1.to(text_encoder1.device) + input_ids2 = input_ids2.to(text_encoder2.device) + + # text_encoder1 + enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True) + hidden_states1 = enc_out["hidden_states"][11] + + # text_encoder2 + enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True) + hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer + + # pool2 = enc_out["text_embeds"] + unwrapped_text_encoder2 = unwrapped_text_encoder2 or text_encoder2 + pool2 = self._pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id) + + # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280 + n_size = 1 if max_token_length is None else max_token_length // 75 + hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1])) + hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1])) + + if max_token_length is not None: + # bs*3, 77, 768 or 1024 + # encoder1: ... の三連を ... へ戻す + states_list = [hidden_states1[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, tokenizer1.model_max_length): + states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # の後から の前まで + states_list.append(hidden_states1[:, -1].unsqueeze(1)) # + hidden_states1 = torch.cat(states_list, dim=1) + + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [hidden_states2[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, tokenizer2.model_max_length): + chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # の後から 最後の前まで + # this causes an error: + # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation + # if i > 1: + # for j in range(len(chunk)): # batch_size + # if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり ...のパターン + # chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(hidden_states2[:, -1].unsqueeze(1)) # のどちらか + hidden_states2 = torch.cat(states_list, dim=1) + + # pool はnの最初のものを使う + pool2 = pool2[::n_size] + + return hidden_states1, hidden_states2, pool2 + + def encode_tokens( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] + ) -> List[torch.Tensor]: + """ + Args: + tokenize_strategy: TokenizeStrategy + models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)]. + If text_encoder2 is wrapped by accelerate, unwrapped_text_encoder2 is required + tokens: List of tokens, for text_encoder1 and text_encoder2 + """ + if len(models) == 2: + text_encoder1, text_encoder2 = models + unwrapped_text_encoder2 = None + else: + text_encoder1, text_encoder2, unwrapped_text_encoder2 = models + tokens1, tokens2 = tokens + sdxl_tokenize_strategy = tokenize_strategy # type: SdxlTokenizeStrategy + tokenizer1, tokenizer2 = sdxl_tokenize_strategy.tokenizer1, sdxl_tokenize_strategy.tokenizer2 + + hidden_states1, hidden_states2, pool2 = self._get_hidden_states_sdxl( + tokens1, tokens2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, unwrapped_text_encoder2 + ) + return [hidden_states1, hidden_states2, pool2] + + def encode_tokens_with_weights( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens_list: List[torch.Tensor], + weights_list: List[torch.Tensor], + ) -> List[torch.Tensor]: + hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens_list) + + weights_list = [weights.to(hidden_states1.device) for weights in weights_list] + + # apply weights + if weights_list[0].shape[1] == 1: # no max_token_length + # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) + hidden_states1 = hidden_states1 * weights_list[0].squeeze(1).unsqueeze(2) + hidden_states2 = hidden_states2 * weights_list[1].squeeze(1).unsqueeze(2) + else: + # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) + for weight, hidden_states in zip(weights_list, [hidden_states1, hidden_states2]): + for i in range(weight.shape[1]): + hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[ + :, i, 1:-1 + ].unsqueeze(-1) + + return [hidden_states1, hidden_states2, pool2] + + +class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): + SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz" + + def __init__( + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + is_weighted: bool = False, + ) -> None: + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted) + + def get_outputs_npz_path(self, image_abs_path: str) -> str: + return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX + + def is_disk_cached_outputs_expected(self, npz_path: str): + if not self.cache_to_disk: + return False + if not os.path.exists(npz_path): + return False + if self.skip_disk_cache_validity_check: + return True + + try: + npz = np.load(npz_path) + if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]: + data = np.load(npz_path) + hidden_state1 = data["hidden_state1"] + hidden_state2 = data["hidden_state2"] + pool2 = data["pool2"] + return [hidden_state1, hidden_state2, pool2] + + def cache_batch_outputs( + self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List + ): + sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy + captions = [info.caption for info in infos] + + if self.is_weighted: + tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions) + with torch.no_grad(): + hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, models, tokens_list, weights_list + ) + else: + tokens1, tokens2 = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, [tokens1, tokens2] + ) + + if hidden_state1.dtype == torch.bfloat16: + hidden_state1 = hidden_state1.float() + if hidden_state2.dtype == torch.bfloat16: + hidden_state2 = hidden_state2.float() + if pool2.dtype == torch.bfloat16: + pool2 = pool2.float() + + hidden_state1 = hidden_state1.cpu().numpy() + hidden_state2 = hidden_state2.cpu().numpy() + pool2 = pool2.cpu().numpy() + + for i, info in enumerate(infos): + hidden_state1_i = hidden_state1[i] + hidden_state2_i = hidden_state2[i] + pool2_i = pool2[i] + + if self.cache_to_disk: + np.savez( + info.text_encoder_outputs_npz, + hidden_state1=hidden_state1_i, + hidden_state2=hidden_state2_i, + pool2=pool2_i, + ) + else: + info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i] diff --git a/library/train_util.py b/library/train_util.py new file mode 100644 index 0000000000000000000000000000000000000000..48f35dd30c92630b0e472d1db46c12216084cfc2 --- /dev/null +++ b/library/train_util.py @@ -0,0 +1,6287 @@ +# common functions for training + +import argparse +import ast +import asyncio +from concurrent.futures import Future, ThreadPoolExecutor +import datetime +import importlib +import json +import logging +import pathlib +import re +import shutil +import time +from typing import ( + Any, + Callable, + Dict, + List, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) +from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState +import glob +import math +import os +import random +import hashlib +import subprocess +from io import BytesIO +import toml + +from tqdm import tqdm + +import torch +from .device_utils import init_ipex, clean_memory_on_device +from .strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy + +init_ipex() + +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torchvision import transforms +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection +import transformers +from diffusers.optimization import ( + SchedulerType as DiffusersSchedulerType, + TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION, +) +from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION +from diffusers import ( + StableDiffusionPipeline, + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + AutoencoderKL, +) +from . import custom_train_functions, sd3_utils +from .original_unet import UNet2DConditionModel +from huggingface_hub import hf_hub_download +import numpy as np +from PIL import Image +import imagesize +import cv2 +import safetensors.torch +from . import model_util as model_util +from . import huggingface_util as huggingface_util +from . import sai_model_spec as sai_model_spec +from . import deepspeed_utils as deepspeed_utils +from .utils import setup_logging, pil_resize + +setup_logging() +import logging + +logger = logging.getLogger(__name__) +# from library.attention_processors import FlashAttnProcessor +# from library.hypernetwork import replace_attentions_for_hypernetwork +from .original_unet import UNet2DConditionModel + +HIGH_VRAM = False + +# checkpointファイル名 +EPOCH_STATE_NAME = "{}-{:06d}-state" +EPOCH_FILE_NAME = "{}-{:06d}" +EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}" +LAST_STATE_NAME = "{}-state" +DEFAULT_EPOCH_NAME = "epoch" +DEFAULT_LAST_OUTPUT_NAME = "last" + +DEFAULT_STEP_NAME = "at" +STEP_STATE_NAME = "{}-step{:05d}-state" +STEP_FILE_NAME = "{}-step{:05d}" +STEP_DIFFUSERS_DIR_NAME = "{}-step{:08d}" + +# region dataset + +IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] + +try: + import pillow_avif + + IMAGE_EXTENSIONS.extend([".avif", ".AVIF"]) +except: + pass + +# JPEG-XL on Linux +try: + from jxlpy import JXLImagePlugin + + IMAGE_EXTENSIONS.extend([".jxl", ".JXL"]) +except: + pass + +# JPEG-XL on Windows +try: + import pillow_jxl + + IMAGE_EXTENSIONS.extend([".jxl", ".JXL"]) +except: + pass + +IMAGE_TRANSFORMS = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] +) + +TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" +TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" + + +class ImageInfo: + def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: + self.image_key: str = image_key + self.num_repeats: int = num_repeats + self.caption: str = caption + self.is_reg: bool = is_reg + self.absolute_path: str = absolute_path + self.image_size: Tuple[int, int] = None + self.resized_size: Tuple[int, int] = None + self.bucket_reso: Tuple[int, int] = None + self.latents: Optional[torch.Tensor] = None + self.latents_flipped: Optional[torch.Tensor] = None + self.latents_npz: Optional[str] = None # set in cache_latents + self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size + self.latents_crop_ltrb: Optional[Tuple[int, int]] = ( + None # crop left top right bottom in original pixel size, not latents size + ) + self.cond_img_path: Optional[str] = None + self.image: Optional[Image.Image] = None # optional, original PIL Image + self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs + + # new + self.text_encoder_outputs: Optional[List[torch.Tensor]] = None + # old + self.text_encoder_outputs1: Optional[torch.Tensor] = None + self.text_encoder_outputs2: Optional[torch.Tensor] = None + self.text_encoder_pool2: Optional[torch.Tensor] = None + + self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime + + +class BucketManager: + def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: + if max_size is not None: + if max_reso is not None: + assert max_size >= max_reso[0], "the max_size should be larger than the width of max_reso" + assert max_size >= max_reso[1], "the max_size should be larger than the height of max_reso" + if min_size is not None: + assert max_size >= min_size, "the max_size should be larger than the min_size" + + self.no_upscale = no_upscale + if max_reso is None: + self.max_reso = None + self.max_area = None + else: + self.max_reso = max_reso + self.max_area = max_reso[0] * max_reso[1] + self.min_size = min_size + self.max_size = max_size + self.reso_steps = reso_steps + + self.resos = [] + self.reso_to_id = {} + self.buckets = [] # 前処理時は (image_key, image, original size, crop left/top)、学習時は image_key + + def add_image(self, reso, image_or_info): + bucket_id = self.reso_to_id[reso] + self.buckets[bucket_id].append(image_or_info) + + def shuffle(self): + for bucket in self.buckets: + random.shuffle(bucket) + + def sort(self): + # 解像度順にソートする(表示時、メタデータ格納時の見栄えをよくするためだけ)。bucketsも入れ替えてreso_to_idも振り直す + sorted_resos = self.resos.copy() + sorted_resos.sort() + + sorted_buckets = [] + sorted_reso_to_id = {} + for i, reso in enumerate(sorted_resos): + bucket_id = self.reso_to_id[reso] + sorted_buckets.append(self.buckets[bucket_id]) + sorted_reso_to_id[reso] = i + + self.resos = sorted_resos + self.buckets = sorted_buckets + self.reso_to_id = sorted_reso_to_id + + def make_buckets(self): + resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps) + self.set_predefined_resos(resos) + + def set_predefined_resos(self, resos): + # 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく + self.predefined_resos = resos.copy() + self.predefined_resos_set = set(resos) + self.predefined_aspect_ratios = np.array([w / h for w, h in resos]) + + def add_if_new_reso(self, reso): + if reso not in self.reso_to_id: + bucket_id = len(self.resos) + self.reso_to_id[reso] = bucket_id + self.resos.append(reso) + self.buckets.append([]) + # logger.info(reso, bucket_id, len(self.buckets)) + + def round_to_steps(self, x): + x = int(x + 0.5) + return x - x % self.reso_steps + + def select_bucket(self, image_width, image_height): + aspect_ratio = image_width / image_height + if not self.no_upscale: + # 拡大および縮小を行う + # 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する + reso = (image_width, image_height) + if reso in self.predefined_resos_set: + pass + else: + ar_errors = self.predefined_aspect_ratios - aspect_ratio + predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの + reso = self.predefined_resos[predefined_bucket_id] + + ar_reso = reso[0] / reso[1] + if aspect_ratio > ar_reso: # 横が長い→縦を合わせる + scale = reso[1] / image_height + else: + scale = reso[0] / image_width + + resized_size = (int(image_width * scale + 0.5), int(image_height * scale + 0.5)) + # logger.info(f"use predef, {image_width}, {image_height}, {reso}, {resized_size}") + else: + # 縮小のみを行う + if image_width * image_height > self.max_area: + # 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める + resized_width = math.sqrt(self.max_area * aspect_ratio) + resized_height = self.max_area / resized_width + assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal" + + # リサイズ後の短辺または長辺をreso_steps単位にする:aspect ratioの差が少ないほうを選ぶ + # 元のbucketingと同じロジック + b_width_rounded = self.round_to_steps(resized_width) + b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio) + ar_width_rounded = b_width_rounded / b_height_in_wr + + b_height_rounded = self.round_to_steps(resized_height) + b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio) + ar_height_rounded = b_width_in_hr / b_height_rounded + + # logger.info(b_width_rounded, b_height_in_wr, ar_width_rounded) + # logger.info(b_width_in_hr, b_height_rounded, ar_height_rounded) + + if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio): + resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + 0.5)) + else: + resized_size = (int(b_height_rounded * aspect_ratio + 0.5), b_height_rounded) + # logger.info(resized_size) + else: + resized_size = (image_width, image_height) # リサイズは不要 + + # 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする) + bucket_width = resized_size[0] - resized_size[0] % self.reso_steps + bucket_height = resized_size[1] - resized_size[1] % self.reso_steps + # logger.info(f"use arbitrary {image_width}, {image_height}, {resized_size}, {bucket_width}, {bucket_height}") + + reso = (bucket_width, bucket_height) + + self.add_if_new_reso(reso) + + ar_error = (reso[0] / reso[1]) - aspect_ratio + return reso, resized_size, ar_error + + @staticmethod + def get_crop_ltrb(bucket_reso: Tuple[int, int], image_size: Tuple[int, int]): + # Stability AIの前処理に合わせてcrop left/topを計算する。crop rightはflipのaugmentationのために求める + # Calculate crop left/top according to the preprocessing of Stability AI. Crop right is calculated for flip augmentation. + + bucket_ar = bucket_reso[0] / bucket_reso[1] + image_ar = image_size[0] / image_size[1] + if bucket_ar > image_ar: + # bucketのほうが横長→縦を合わせる + resized_width = bucket_reso[1] * image_ar + resized_height = bucket_reso[1] + else: + resized_width = bucket_reso[0] + resized_height = bucket_reso[0] / image_ar + crop_left = (bucket_reso[0] - resized_width) // 2 + crop_top = (bucket_reso[1] - resized_height) // 2 + crop_right = crop_left + resized_width + crop_bottom = crop_top + resized_height + return crop_left, crop_top, crop_right, crop_bottom + + +class BucketBatchIndex(NamedTuple): + bucket_index: int + bucket_batch_size: int + batch_index: int + + +class AugHelper: + # albumentationsへの依存をなくしたがとりあえず同じinterfaceを持たせる + + def __init__(self): + pass + + def color_aug(self, image: np.ndarray): + # self.color_aug_method = albu.OneOf( + # [ + # albu.HueSaturationValue(8, 0, 0, p=0.5), + # albu.RandomGamma((95, 105), p=0.5), + # ], + # p=0.33, + # ) + hue_shift_limit = 8 + + # remove dependency to albumentations + if random.random() <= 0.33: + if random.random() > 0.5: + # hue shift + hsv_img = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + hue_shift = random.uniform(-hue_shift_limit, hue_shift_limit) + if hue_shift < 0: + hue_shift = 180 + hue_shift + hsv_img[:, :, 0] = (hsv_img[:, :, 0] + hue_shift) % 180 + image = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR) + else: + # random gamma + gamma = random.uniform(0.95, 1.05) + image = np.clip(image**gamma, 0, 255).astype(np.uint8) + + return {"image": image} + + def get_augmentor(self, use_color_aug: bool): # -> Optional[Callable[[np.ndarray], Dict[str, np.ndarray]]]: + return self.color_aug if use_color_aug else None + + +class BaseSubset: + def __init__( + self, + image_dir: Optional[str], + alpha_mask: Optional[bool], + num_repeats: int, + shuffle_caption: bool, + caption_separator: str, + keep_tokens: int, + keep_tokens_separator: str, + secondary_separator: Optional[str], + enable_wildcard: bool, + color_aug: bool, + flip_aug: bool, + face_crop_aug_range: Optional[Tuple[float, float]], + random_crop: bool, + caption_dropout_rate: float, + caption_dropout_every_n_epochs: int, + caption_tag_dropout_rate: float, + caption_prefix: Optional[str], + caption_suffix: Optional[str], + token_warmup_min: int, + token_warmup_step: Union[float, int], + custom_attributes: Optional[Dict[str, Any]] = None, + ) -> None: + self.image_dir = image_dir + self.alpha_mask = alpha_mask if alpha_mask is not None else False + self.num_repeats = num_repeats + self.shuffle_caption = shuffle_caption + self.caption_separator = caption_separator + self.keep_tokens = keep_tokens + self.keep_tokens_separator = keep_tokens_separator + self.secondary_separator = secondary_separator + self.enable_wildcard = enable_wildcard + self.color_aug = color_aug + self.flip_aug = flip_aug + self.face_crop_aug_range = face_crop_aug_range + self.random_crop = random_crop + self.caption_dropout_rate = caption_dropout_rate + self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs + self.caption_tag_dropout_rate = caption_tag_dropout_rate + self.caption_prefix = caption_prefix + self.caption_suffix = caption_suffix + + self.token_warmup_min = token_warmup_min # step=0におけるタグの数 + self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる + + self.custom_attributes = custom_attributes if custom_attributes is not None else {} + + self.img_count = 0 + + +class DreamBoothSubset(BaseSubset): + def __init__( + self, + image_dir: str, + is_reg: bool, + class_tokens: Optional[str], + caption_extension: str, + cache_info: bool, + alpha_mask: bool, + num_repeats, + shuffle_caption, + caption_separator: str, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, + custom_attributes: Optional[Dict[str, Any]] = None, + ) -> None: + assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" + + super().__init__( + image_dir, + alpha_mask, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, + custom_attributes=custom_attributes, + ) + + self.is_reg = is_reg + self.class_tokens = class_tokens + self.caption_extension = caption_extension + if self.caption_extension and not self.caption_extension.startswith("."): + self.caption_extension = "." + self.caption_extension + self.cache_info = cache_info + + def __eq__(self, other) -> bool: + if not isinstance(other, DreamBoothSubset): + return NotImplemented + return self.image_dir == other.image_dir + + +class FineTuningSubset(BaseSubset): + def __init__( + self, + image_dir, + metadata_file: str, + alpha_mask: bool, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, + custom_attributes: Optional[Dict[str, Any]] = None, + ) -> None: + assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" + + super().__init__( + image_dir, + alpha_mask, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, + custom_attributes=custom_attributes, + ) + + self.metadata_file = metadata_file + + def __eq__(self, other) -> bool: + if not isinstance(other, FineTuningSubset): + return NotImplemented + return self.metadata_file == other.metadata_file + + +class ControlNetSubset(BaseSubset): + def __init__( + self, + image_dir: str, + conditioning_data_dir: str, + caption_extension: str, + cache_info: bool, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, + custom_attributes: Optional[Dict[str, Any]] = None, + ) -> None: + assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" + + super().__init__( + image_dir, + False, # alpha_mask + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, + custom_attributes=custom_attributes, + ) + + self.conditioning_data_dir = conditioning_data_dir + self.caption_extension = caption_extension + if self.caption_extension and not self.caption_extension.startswith("."): + self.caption_extension = "." + self.caption_extension + self.cache_info = cache_info + + def __eq__(self, other) -> bool: + if not isinstance(other, ControlNetSubset): + return NotImplemented + return self.image_dir == other.image_dir and self.conditioning_data_dir == other.conditioning_data_dir + + +class BaseDataset(torch.utils.data.Dataset): + def __init__( + self, + resolution: Optional[Tuple[int, int]], + network_multiplier: float, + debug_dataset: bool, + ) -> None: + super().__init__() + + # width/height is used when enable_bucket==False + self.width, self.height = (None, None) if resolution is None else resolution + self.network_multiplier = network_multiplier + self.debug_dataset = debug_dataset + + self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = [] + + self.token_padding_disabled = False + self.tag_frequency = {} + self.XTI_layers = None + self.token_strings = None + + self.enable_bucket = False + self.bucket_manager: BucketManager = None # not initialized + self.min_bucket_reso = None + self.max_bucket_reso = None + self.bucket_reso_steps = None + self.bucket_no_upscale = None + self.bucket_info = None # for metadata + + self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ + + self.current_step: int = 0 + self.max_train_steps: int = 0 + self.seed: int = 0 + + # augmentation + self.aug_helper = AugHelper() + + self.image_transforms = IMAGE_TRANSFORMS + + self.image_data: Dict[str, ImageInfo] = {} + self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} + + self.replacements = {} + + # caching + self.caching_mode = None # None, 'latents', 'text' + + self.tokenize_strategy = None + self.text_encoder_output_caching_strategy = None + self.latents_caching_strategy = None + + def set_current_strategies(self): + self.tokenize_strategy = TokenizeStrategy.get_strategy() + self.text_encoder_output_caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy() + self.latents_caching_strategy = LatentsCachingStrategy.get_strategy() + + def adjust_min_max_bucket_reso_by_steps( + self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int + ) -> Tuple[int, int]: + # make min/max bucket reso to be multiple of bucket_reso_steps + if min_bucket_reso % bucket_reso_steps != 0: + adjusted_min_bucket_reso = min_bucket_reso - min_bucket_reso % bucket_reso_steps + logger.warning( + f"min_bucket_reso is adjusted to be multiple of bucket_reso_steps" + f" / min_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {min_bucket_reso} -> {adjusted_min_bucket_reso}" + ) + min_bucket_reso = adjusted_min_bucket_reso + if max_bucket_reso % bucket_reso_steps != 0: + adjusted_max_bucket_reso = max_bucket_reso + bucket_reso_steps - max_bucket_reso % bucket_reso_steps + logger.warning( + f"max_bucket_reso is adjusted to be multiple of bucket_reso_steps" + f" / max_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {max_bucket_reso} -> {adjusted_max_bucket_reso}" + ) + max_bucket_reso = adjusted_max_bucket_reso + + assert ( + min(resolution) >= min_bucket_reso + ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" + assert ( + max(resolution) <= max_bucket_reso + ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + + return min_bucket_reso, max_bucket_reso + + def set_seed(self, seed): + self.seed = seed + + def set_caching_mode(self, mode): + self.caching_mode = mode + + def set_current_epoch(self, epoch): + if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする + if epoch > self.current_epoch: + #logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) + num_epochs = epoch - self.current_epoch + for _ in range(num_epochs): + self.current_epoch += 1 + self.shuffle_buckets() + # self.current_epoch seem to be set to 0 again in the next epoch. it may be caused by skipped_dataloader? + else: + #logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) + self.current_epoch = epoch + + def set_current_step(self, step): + self.current_step = step + + def set_max_train_steps(self, max_train_steps): + self.max_train_steps = max_train_steps + + def set_tag_frequency(self, dir_name, captions): + frequency_for_dir = self.tag_frequency.get(dir_name, {}) + self.tag_frequency[dir_name] = frequency_for_dir + for caption in captions: + for tag in caption.split(","): + tag = tag.strip() + if tag: + tag = tag.lower() + frequency = frequency_for_dir.get(tag, 0) + frequency_for_dir[tag] = frequency + 1 + + def disable_token_padding(self): + self.token_padding_disabled = True + + def enable_XTI(self, layers=None, token_strings=None): + self.XTI_layers = layers + self.token_strings = token_strings + + def add_replacement(self, str_from, str_to): + self.replacements[str_from] = str_to + + def process_caption(self, subset: BaseSubset, caption): + # caption に prefix/suffix を付ける + if subset.caption_prefix: + caption = subset.caption_prefix + " " + caption + if subset.caption_suffix: + caption = caption + " " + subset.caption_suffix + + # dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い + is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate + is_drop_out = ( + is_drop_out + or subset.caption_dropout_every_n_epochs > 0 + and self.current_epoch % subset.caption_dropout_every_n_epochs == 0 + ) + + if is_drop_out: + caption = "" + else: + # process wildcards + if subset.enable_wildcard: + # if caption is multiline, random choice one line + if "\n" in caption: + caption = random.choice(caption.split("\n")) + + # wildcard is like '{aaa|bbb|ccc...}' + # escape the curly braces like {{ or }} + replacer1 = "⦅" + replacer2 = "⦆" + while replacer1 in caption or replacer2 in caption: + replacer1 += "⦅" + replacer2 += "⦆" + + caption = caption.replace("{{", replacer1).replace("}}", replacer2) + + # replace the wildcard + def replace_wildcard(match): + return random.choice(match.group(1).split("|")) + + caption = re.sub(r"\{([^}]+)\}", replace_wildcard, caption) + + # unescape the curly braces + caption = caption.replace(replacer1, "{").replace(replacer2, "}") + else: + # if caption is multiline, use the first line + caption = caption.split("\n")[0] + + if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: + fixed_tokens = [] + flex_tokens = [] + fixed_suffix_tokens = [] + if ( + hasattr(subset, "keep_tokens_separator") + and subset.keep_tokens_separator + and subset.keep_tokens_separator in caption + ): + fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1) + if subset.keep_tokens_separator in flex_part: + flex_part, fixed_suffix_part = flex_part.split(subset.keep_tokens_separator, 1) + fixed_suffix_tokens = [t.strip() for t in fixed_suffix_part.split(subset.caption_separator) if t.strip()] + + fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()] + flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()] + else: + tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)] + flex_tokens = tokens[:] + if subset.keep_tokens > 0: + fixed_tokens = flex_tokens[: subset.keep_tokens] + flex_tokens = tokens[subset.keep_tokens :] + + if subset.token_warmup_step < 1: # 初回に上書きする + subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) + if subset.token_warmup_step and self.current_step < subset.token_warmup_step: + tokens_len = ( + math.floor( + (self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step)) + ) + + subset.token_warmup_min + ) + flex_tokens = flex_tokens[:tokens_len] + + def dropout_tags(tokens): + if subset.caption_tag_dropout_rate <= 0: + return tokens + l = [] + for token in tokens: + if random.random() >= subset.caption_tag_dropout_rate: + l.append(token) + return l + + if subset.shuffle_caption: + random.shuffle(flex_tokens) + + flex_tokens = dropout_tags(flex_tokens) + + caption = ", ".join(fixed_tokens + flex_tokens + fixed_suffix_tokens) + + # process secondary separator + if subset.secondary_separator: + caption = caption.replace(subset.secondary_separator, subset.caption_separator) + + # textual inversion対応 + for str_from, str_to in self.replacements.items(): + if str_from == "": + # replace all + if type(str_to) == list: + caption = random.choice(str_to) + else: + caption = str_to + else: + caption = caption.replace(str_from, str_to) + + return caption + + def get_input_ids(self, caption, tokenizer=None): + if tokenizer is None: + tokenizer = self.tokenizers[0] + + input_ids = tokenizer( + caption, padding="max_length", truncation=True, max_length=self.tokenizer_max_length, return_tensors="pt" + ).input_ids + + if self.tokenizer_max_length > tokenizer.model_max_length: + input_ids = input_ids.squeeze(0) + iids_list = [] + if tokenizer.pad_token_id == tokenizer.eos_token_id: + # v1 + # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する + # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に + for i in range( + 1, self.tokenizer_max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2 + ): # (1, 152, 75) + ids_chunk = ( + input_ids[0].unsqueeze(0), + input_ids[i : i + tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 or SDXL + # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する + for i in range(1, self.tokenizer_max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): + ids_chunk = ( + input_ids[0].unsqueeze(0), # BOS + input_ids[i : i + tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0), + ) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) + + # 末尾が または の場合は、何もしなくてよい + # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) + if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id: + ids_chunk[-1] = tokenizer.eos_token_id + # 先頭が ... の場合は ... に変える + if ids_chunk[1] == tokenizer.pad_token_id: + ids_chunk[1] = tokenizer.eos_token_id + + iids_list.append(ids_chunk) + + input_ids = torch.stack(iids_list) # 3,77 + return input_ids + + def register_image(self, info: ImageInfo, subset: BaseSubset): + self.image_data[info.image_key] = info + self.image_to_subset[info.image_key] = subset + + def make_buckets(self): + """ + bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) + min_size and max_size are ignored when enable_bucket is False + """ + logger.info("loading image sizes.") + for info in tqdm(self.image_data.values()): + if info.image_size is None: + info.image_size = self.get_image_size(info.absolute_path) + + if self.enable_bucket: + logger.info("make buckets") + else: + logger.info("prepare dataset") + + # bucketを作成し、画像をbucketに振り分ける + if self.enable_bucket: + if self.bucket_manager is None: # fine tuningの場合でmetadataに定義がある場合は、すでに初期化済み + self.bucket_manager = BucketManager( + self.bucket_no_upscale, + (self.width, self.height), + self.min_bucket_reso, + self.max_bucket_reso, + self.bucket_reso_steps, + ) + if not self.bucket_no_upscale: + self.bucket_manager.make_buckets() + else: + logger.warning( + "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます" + ) + + img_ar_errors = [] + for image_info in self.image_data.values(): + image_width, image_height = image_info.image_size + image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket( + image_width, image_height + ) + + # logger.info(image_info.image_key, image_info.bucket_reso) + img_ar_errors.append(abs(ar_error)) + + self.bucket_manager.sort() + else: + self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None) + self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ + for image_info in self.image_data.values(): + image_width, image_height = image_info.image_size + image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height) + + for image_info in self.image_data.values(): + for _ in range(image_info.num_repeats): + self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key) + + # bucket情報を表示、格納する + if self.enable_bucket: + self.bucket_info = {"buckets": {}} + logger.info("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") + for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)): + count = len(bucket) + if count > 0: + self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)} + logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}") + + img_ar_errors = np.array(img_ar_errors) + mean_img_ar_error = np.mean(np.abs(img_ar_errors)) + self.bucket_info["mean_img_ar_error"] = mean_img_ar_error + logger.info(f"mean ar error (without repeats): {mean_img_ar_error}") + + # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる + self.buckets_indices: List[BucketBatchIndex] = [] + for bucket_index, bucket in enumerate(self.bucket_manager.buckets): + batch_count = int(math.ceil(len(bucket) / self.batch_size)) + for batch_index in range(batch_count): + self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index)) + + self.shuffle_buckets() + self._length = len(self.buckets_indices) + + def shuffle_buckets(self): + # set random seed for this epoch + random.seed(self.seed + self.current_epoch) + + random.shuffle(self.buckets_indices) + self.bucket_manager.shuffle() + + def verify_bucket_reso_steps(self, min_steps: int): + assert self.bucket_reso_steps is None or self.bucket_reso_steps % min_steps == 0, ( + f"bucket_reso_steps is {self.bucket_reso_steps}. it must be divisible by {min_steps}.\n" + + f"bucket_reso_stepsが{self.bucket_reso_steps}です。{min_steps}で割り切れる必要があります" + ) + + def is_latent_cacheable(self): + return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) + + def is_text_encoder_output_cacheable(self): + return all( + [ + not ( + subset.caption_dropout_rate > 0 + or subset.shuffle_caption + or subset.token_warmup_step > 0 + or subset.caption_tag_dropout_rate > 0 + ) + for subset in self.subsets + ] + ) + + def new_cache_latents(self, model: Any, accelerator: Accelerator): + r""" + a brand new method to cache latents. This method caches latents with caching strategy. + normal cache_latents method is used by default, but this method is used when caching strategy is specified. + """ + logger.info("caching latents with caching strategy.") + caching_strategy = LatentsCachingStrategy.get_strategy() + image_infos = list(self.image_data.values()) + + # sort by resolution + image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) + + # split by resolution and some conditions + class Condition: + def __init__(self, reso, flip_aug, alpha_mask, random_crop): + self.reso = reso + self.flip_aug = flip_aug + self.alpha_mask = alpha_mask + self.random_crop = random_crop + + def __eq__(self, other): + return ( + self.reso == other.reso + and self.flip_aug == other.flip_aug + and self.alpha_mask == other.alpha_mask + and self.random_crop == other.random_crop + ) + + batch: List[ImageInfo] = [] + current_condition = None + + # support multiple-gpus + num_processes = accelerator.num_processes + process_index = accelerator.process_index + + # define a function to submit a batch to cache + def submit_batch(batch, cond): + for info in batch: + if info.image is not None and isinstance(info.image, Future): + info.image = info.image.result() # future to image + caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop) + + # remove image from memory + for info in batch: + info.image = None + + # define ThreadPoolExecutor to load images in parallel + max_workers = min(os.cpu_count(), len(image_infos)) + max_workers = max(1, max_workers // num_processes) # consider multi-gpu + max_workers = min(max_workers, caching_strategy.batch_size) # max_workers should be less than batch_size + executor = ThreadPoolExecutor(max_workers) + + try: + # iterate images + logger.info("caching latents...") + for i, info in enumerate(tqdm(image_infos)): + subset = self.image_to_subset[info.image_key] + + if info.latents_npz is not None: # fine tuning dataset + continue + + # check disk cache exists and size of latents + if caching_strategy.cache_to_disk: + # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix + info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) + + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different latents + if i % num_processes != process_index: + continue + + # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") + + cache_available = caching_strategy.is_disk_cached_latents_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) + if cache_available: # do not add to batch + continue + + # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + submit_batch(batch, current_condition) + batch = [] + + if info.image is None: + # load image in parallel + info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask) + + batch.append(info) + current_condition = condition + + # if number of data in batch is enough, flush the batch + if len(batch) >= caching_strategy.batch_size: + submit_batch(batch, current_condition) + batch = [] + current_condition = None + + if len(batch) > 0: + submit_batch(batch, current_condition) + + finally: + executor.shutdown() + + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): + # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと + logger.info("caching latents.") + + image_infos = list(self.image_data.values()) + + # sort by resolution + image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) + + # split by resolution and some conditions + class Condition: + def __init__(self, reso, flip_aug, alpha_mask, random_crop): + self.reso = reso + self.flip_aug = flip_aug + self.alpha_mask = alpha_mask + self.random_crop = random_crop + + def __eq__(self, other): + return ( + self.reso == other.reso + and self.flip_aug == other.flip_aug + and self.alpha_mask == other.alpha_mask + and self.random_crop == other.random_crop + ) + + batches: List[Tuple[Condition, List[ImageInfo]]] = [] + batch: List[ImageInfo] = [] + current_condition = None + + logger.info("checking cache validity...") + for info in tqdm(image_infos): + subset = self.image_to_subset[info.image_key] + + if info.latents_npz is not None: # fine tuning dataset + continue + + # check disk cache exists and size of latents + if cache_to_disk: + info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix + if not is_main_process: # store to info only + continue + + cache_available = is_disk_cached_latents_is_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) + + if cache_available: # do not add to batch + continue + + # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + batches.append((current_condition, batch)) + batch = [] + + batch.append(info) + current_condition = condition + + # if number of data in batch is enough, flush the batch + if len(batch) >= vae_batch_size: + batches.append((current_condition, batch)) + batch = [] + current_condition = None + + if len(batch) > 0: + batches.append((current_condition, batch)) + + if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only + return + + # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded + logger.info("caching latents...") + for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): + cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) + + def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator): + r""" + a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy. + """ + tokenize_strategy = TokenizeStrategy.get_strategy() + text_encoding_strategy = TextEncodingStrategy.get_strategy() + caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy() + batch_size = caching_strategy.batch_size or self.batch_size + + logger.info("caching Text Encoder outputs with caching strategy.") + image_infos = list(self.image_data.values()) + + # split by resolution + batches = [] + batch = [] + + # support multiple-gpus + num_processes = accelerator.num_processes + process_index = accelerator.process_index + + logger.info("checking cache validity...") + for i, info in enumerate(tqdm(image_infos)): + # check disk cache exists and size of text encoder outputs + if caching_strategy.cache_to_disk: + te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) + info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability + + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different text encoder outputs + if i % num_processes != process_index: + continue + + cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) + if cache_available: # do not add to batch + continue + + batch.append(info) + + # if number of data in batch is enough, flush the batch + if len(batch) >= batch_size: + batches.append(batch) + batch = [] + + if len(batch) > 0: + batches.append(batch) + + if len(batches) == 0: + logger.info("no Text Encoder outputs to cache") + return + + # iterate batches + logger.info("caching Text Encoder outputs...") + for batch in tqdm(batches, smoothing=1, total=len(batches)): + # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + caching_strategy.cache_batch_outputs(tokenize_strategy, models, text_encoding_strategy, batch) + + # if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype + # this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset + # to support SD1/2, it needs a flag for v2, but it is postponed + def cache_text_encoder_outputs( + self, tokenizers, text_encoders, device, output_dtype, cache_to_disk=False, is_main_process=True + ): + assert len(tokenizers) == 2, "only support SDXL" + return self.cache_text_encoder_outputs_common( + tokenizers, text_encoders, [device, device], output_dtype, [output_dtype], cache_to_disk, is_main_process + ) + + # same as above, but for SD3 + def cache_text_encoder_outputs_sd3( + self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None + ): + return self.cache_text_encoder_outputs_common( + [tokenizer], + text_encoders, + devices, + output_dtype, + te_dtypes, + cache_to_disk, + is_main_process, + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3, + batch_size, + ) + + def cache_text_encoder_outputs_common( + self, + tokenizers, + text_encoders, + devices, + output_dtype, + te_dtypes, + cache_to_disk=False, + is_main_process=True, + file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX, + batch_size=None, + ): + # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する + # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと + logger.info("caching text encoder outputs.") + + tokenize_strategy = TokenizeStrategy.get_strategy() + + if batch_size is None: + batch_size = self.batch_size + + image_infos = list(self.image_data.values()) + + logger.info("checking cache existence...") + image_infos_to_cache = [] + for info in tqdm(image_infos): + # subset = self.image_to_subset[info.image_key] + if cache_to_disk: + te_out_npz = os.path.splitext(info.absolute_path)[0] + file_suffix + info.text_encoder_outputs_npz = te_out_npz + + if not is_main_process: # store to info only + continue + + if os.path.exists(te_out_npz): + # TODO check varidity of cache here + continue + + image_infos_to_cache.append(info) + + if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only + return + + # prepare tokenizers and text encoders + for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes): + text_encoder.to(device) + if te_dtype is not None: + text_encoder.to(dtype=te_dtype) + + # create batch + is_sd3 = len(tokenizers) == 1 + batch = [] + batches = [] + for info in image_infos_to_cache: + if not is_sd3: + input_ids1 = self.get_input_ids(info.caption, tokenizers[0]) + input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) + batch.append((info, input_ids1, input_ids2)) + else: + l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(info.caption) + batch.append((info, l_tokens, g_tokens, t5_tokens)) + + if len(batch) >= batch_size: + batches.append(batch) + batch = [] + + if len(batch) > 0: + batches.append(batch) + + # iterate batches: call text encoder and cache outputs for memory or disk + logger.info("caching text encoder outputs...") + if not is_sd3: + for batch in tqdm(batches): + infos, input_ids1, input_ids2 = zip(*batch) + input_ids1 = torch.stack(input_ids1, dim=0) + input_ids2 = torch.stack(input_ids2, dim=0) + cache_batch_text_encoder_outputs( + infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, output_dtype + ) + else: + for batch in tqdm(batches): + infos, l_tokens, g_tokens, t5_tokens = zip(*batch) + + # stack tokens + # l_tokens = [tokens[0] for tokens in l_tokens] + # g_tokens = [tokens[0] for tokens in g_tokens] + # t5_tokens = [tokens[0] for tokens in t5_tokens] + + cache_batch_text_encoder_outputs_sd3( + infos, + tokenizers[0], + text_encoders, + self.max_token_length, + cache_to_disk, + (l_tokens, g_tokens, t5_tokens), + output_dtype, + ) + + def get_image_size(self, image_path): + # return imagesize.get(image_path) + image_size = imagesize.get(image_path) + if image_size[0] <= 0: + # imagesize doesn't work for some images, so use PIL as a fallback + try: + with Image.open(image_path) as img: + image_size = img.size + except Exception as e: + logger.warning(f"failed to get image size: {image_path}, error: {e}") + image_size = (0, 0) + return image_size + + def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False): + img = load_image(image_path, alpha_mask) + + face_cx = face_cy = face_w = face_h = 0 + if subset.face_crop_aug_range is not None: + tokens = os.path.splitext(os.path.basename(image_path))[0].split("_") + if len(tokens) >= 5: + face_cx = int(tokens[-4]) + face_cy = int(tokens[-3]) + face_w = int(tokens[-2]) + face_h = int(tokens[-1]) + + return img, face_cx, face_cy, face_w, face_h + + # いい感じに切り出す + def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h): + height, width = image.shape[0:2] + if height == self.height and width == self.width: + return image + + # 画像サイズはsizeより大きいのでリサイズする + face_size = max(face_w, face_h) + size = min(self.height, self.width) # 短いほう + min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) + min_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ + max_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ + if min_scale >= max_scale: # range指定がmin==max + scale = min_scale + else: + scale = random.uniform(min_scale, max_scale) + + nh = int(height * scale + 0.5) + nw = int(width * scale + 0.5) + assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" + image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) + face_cx = int(face_cx * scale + 0.5) + face_cy = int(face_cy * scale + 0.5) + height, width = nh, nw + + # 顔を中心として448*640とかへ切り出す + for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): + p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 + + if subset.random_crop: + # 背景も含めるために顔を中心に置く確率を高めつつずらす + range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう + p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数 + else: + # range指定があるときのみ、すこしだけランダムに(わりと適当) + if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]: + if face_size > size // 10 and face_size >= 40: + p1 = p1 + random.randint(-face_size // 20, +face_size // 20) + + p1 = max(0, min(p1, length - target_size)) + + if axis == 0: + image = image[p1 : p1 + target_size, :] + else: + image = image[:, p1 : p1 + target_size] + + return image + + def __len__(self): + return self._length + + def __getitem__(self, index): + bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index] + bucket_batch_size = self.buckets_indices[index].bucket_batch_size + image_index = self.buckets_indices[index].batch_index * bucket_batch_size + + if self.caching_mode is not None: # return batch for latents/text encoder outputs caching + return self.get_item_for_caching(bucket, bucket_batch_size, image_index) + + loss_weights = [] + captions = [] + input_ids_list = [] + latents_list = [] + alpha_mask_list = [] + images = [] + original_sizes_hw = [] + crop_top_lefts = [] + target_sizes_hw = [] + flippeds = [] # 変数名が微妙 + text_encoder_outputs_list = [] + custom_attributes = [] + + for image_key in bucket[image_index : image_index + bucket_batch_size]: + image_info = self.image_data[image_key] + subset = self.image_to_subset[image_key] + + custom_attributes.append(subset.custom_attributes) + + # in case of fine tuning, is_reg is always False + loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + + flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance + + # image/latentsを処理する + if image_info.latents is not None: # cache_latents=Trueの場合 + original_size = image_info.latents_original_size + crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped + if not flipped: + latents = image_info.latents + alpha_mask = image_info.alpha_mask + else: + latents = image_info.latents_flipped + alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1]) + + image = None + elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 + latents, original_size, crop_ltrb, flipped_latents, alpha_mask = ( + self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso) + ) + if flipped: + latents = flipped_latents + alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem + del flipped_latents + latents = torch.FloatTensor(latents) + if alpha_mask is not None: + alpha_mask = torch.FloatTensor(alpha_mask) + + image = None + else: + # 画像を読み込み、必要ならcropする + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info( + subset, image_info.absolute_path, subset.alpha_mask + ) + im_h, im_w = img.shape[0:2] + + if self.enable_bucket: + img, original_size, crop_ltrb = trim_and_resize_if_required( + subset.random_crop, img, image_info.bucket_reso, image_info.resized_size + ) + else: + if face_cx > 0: # 顔位置情報あり + img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h) + elif im_h > self.height or im_w > self.width: + assert ( + subset.random_crop + ), f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" + if im_h > self.height: + p = random.randint(0, im_h - self.height) + img = img[p : p + self.height] + if im_w > self.width: + p = random.randint(0, im_w - self.width) + img = img[:, p : p + self.width] + + im_h, im_w = img.shape[0:2] + assert ( + im_h == self.height and im_w == self.width + ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + + original_size = [im_w, im_h] + crop_ltrb = (0, 0, 0, 0) + + # augmentation + aug = self.aug_helper.get_augmentor(subset.color_aug) + if aug is not None: + # augment RGB channels only + img_rgb = img[:, :, :3] + img_rgb = aug(image=img_rgb)["image"] + img[:, :, :3] = img_rgb + + if flipped: + img = img[:, ::-1, :].copy() # copy to avoid negative stride problem + + if subset.alpha_mask: + if img.shape[2] == 4: + alpha_mask = img[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 # 0.0~1.0 + alpha_mask = torch.FloatTensor(alpha_mask) + else: + alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32) + else: + alpha_mask = None + + img = img[:, :, :3] # remove alpha channel + + latents = None + image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + del img + + images.append(image) + latents_list.append(latents) + alpha_mask_list.append(alpha_mask) + + target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8) + + if not flipped: + crop_left_top = (crop_ltrb[0], crop_ltrb[1]) + else: + # crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image + crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1]) + + original_sizes_hw.append((int(original_size[1]), int(original_size[0]))) + crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0]))) + target_sizes_hw.append((int(target_size[1]), int(target_size[0]))) + flippeds.append(flipped) + + # captionとtext encoder outputを処理する + caption = image_info.caption # default + + tokenization_required = ( + self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial + ) + text_encoder_outputs = None + input_ids = None + + if image_info.text_encoder_outputs is not None: + # cached + text_encoder_outputs = image_info.text_encoder_outputs + elif image_info.text_encoder_outputs_npz is not None: + # on disk + text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( + image_info.text_encoder_outputs_npz + ) + else: + tokenization_required = True + text_encoder_outputs_list.append(text_encoder_outputs) + + if tokenization_required: + caption = self.process_caption(subset, image_info.caption) + input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension + # if self.XTI_layers: + # caption_layer = [] + # for layer in self.XTI_layers: + # token_strings_from = " ".join(self.token_strings) + # token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) + # caption_ = caption.replace(token_strings_from, token_strings_to) + # caption_layer.append(caption_) + # captions.append(caption_layer) + # else: + # captions.append(caption) + + # if not self.token_padding_disabled: # this option might be omitted in future + # # TODO get_input_ids must support SD3 + # if self.XTI_layers: + # token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) + # else: + # token_caption = self.get_input_ids(caption, self.tokenizers[0]) + # input_ids_list.append(token_caption) + + # if len(self.tokenizers) > 1: + # if self.XTI_layers: + # token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) + # else: + # token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) + # input_ids2_list.append(token_caption2) + + input_ids_list.append(input_ids) + captions.append(caption) + + def none_or_stack_elements(tensors_list, converter): + # [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)] + if len(tensors_list) == 0 or tensors_list[0] == None or len(tensors_list[0]) == 0 or tensors_list[0][0] is None: + return None + return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))] + + # set example + example = {} + example["custom_attributes"] = custom_attributes # may be list of empty dict + example["loss_weights"] = torch.FloatTensor(loss_weights) + example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor) + example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x) + + # if one of alpha_masks is not None, we need to replace None with ones + none_or_not = [x is None for x in alpha_mask_list] + if all(none_or_not): + example["alpha_masks"] = None + elif any(none_or_not): + for i in range(len(alpha_mask_list)): + if alpha_mask_list[i] is None: + if images[i] is not None: + alpha_mask_list[i] = torch.ones((images[i].shape[1], images[i].shape[2]), dtype=torch.float32) + else: + alpha_mask_list[i] = torch.ones( + (latents_list[i].shape[1] * 8, latents_list[i].shape[2] * 8), dtype=torch.float32 + ) + example["alpha_masks"] = torch.stack(alpha_mask_list) + else: + example["alpha_masks"] = torch.stack(alpha_mask_list) + + if images[0] is not None: + images = torch.stack(images) + images = images.to(memory_format=torch.contiguous_format).float() + else: + images = None + example["images"] = images + + example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None + example["captions"] = captions + + example["original_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in original_sizes_hw]) + example["crop_top_lefts"] = torch.stack([torch.LongTensor(x) for x in crop_top_lefts]) + example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw]) + example["flippeds"] = flippeds + + example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions)) + + if self.debug_dataset: + example["image_keys"] = bucket[image_index : image_index + self.batch_size] + return example + + def get_item_for_caching(self, bucket, bucket_batch_size, image_index): + captions = [] + images = [] + input_ids1_list = [] + input_ids2_list = [] + absolute_paths = [] + resized_sizes = [] + bucket_reso = None + flip_aug = None + alpha_mask = None + random_crop = None + + for image_key in bucket[image_index : image_index + bucket_batch_size]: + image_info = self.image_data[image_key] + subset = self.image_to_subset[image_key] + + if flip_aug is None: + flip_aug = subset.flip_aug + alpha_mask = subset.alpha_mask + random_crop = subset.random_crop + bucket_reso = image_info.bucket_reso + else: + # TODO そもそも混在してても動くようにしたほうがいい + assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch" + assert alpha_mask == subset.alpha_mask, "alpha_mask must be same in a batch" + assert random_crop == subset.random_crop, "random_crop must be same in a batch" + assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch" + + caption = image_info.caption # TODO cache some patterns of dropping, shuffling, etc. + + if self.caching_mode == "latents": + image = load_image(image_info.absolute_path) + else: + image = None + + if self.caching_mode == "text": + input_ids1 = self.get_input_ids(caption, self.tokenizers[0]) + input_ids2 = self.get_input_ids(caption, self.tokenizers[1]) + else: + input_ids1 = None + input_ids2 = None + + captions.append(caption) + images.append(image) + input_ids1_list.append(input_ids1) + input_ids2_list.append(input_ids2) + absolute_paths.append(image_info.absolute_path) + resized_sizes.append(image_info.resized_size) + + example = {} + + if images[0] is None: + images = None + example["images"] = images + + example["captions"] = captions + example["input_ids1_list"] = input_ids1_list + example["input_ids2_list"] = input_ids2_list + example["absolute_paths"] = absolute_paths + example["resized_sizes"] = resized_sizes + example["flip_aug"] = flip_aug + example["alpha_mask"] = alpha_mask + example["random_crop"] = random_crop + example["bucket_reso"] = bucket_reso + return example + + +class DreamBoothDataset(BaseDataset): + IMAGE_INFO_CACHE_FILE = "metadata_cache.json" + + def __init__( + self, + subsets: Sequence[DreamBoothSubset], + batch_size: int, + resolution, + network_multiplier: float, + enable_bucket: bool, + min_bucket_reso: int, + max_bucket_reso: int, + bucket_reso_steps: int, + bucket_no_upscale: bool, + prior_loss_weight: float, + debug_dataset: bool, + ) -> None: + super().__init__(resolution, network_multiplier, debug_dataset) + + assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" + + self.batch_size = batch_size + self.size = min(self.width, self.height) # 短いほう + self.prior_loss_weight = prior_loss_weight + self.latents_cache = None + + self.enable_bucket = enable_bucket + if self.enable_bucket: + min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps( + resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps + ) + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso + self.bucket_reso_steps = bucket_reso_steps + self.bucket_no_upscale = bucket_no_upscale + else: + self.min_bucket_reso = None + self.max_bucket_reso = None + self.bucket_reso_steps = None # この情報は使われない + self.bucket_no_upscale = False + + def read_caption(img_path, caption_extension, enable_wildcard): + # captionの候補ファイル名を作る + base_name = os.path.splitext(img_path)[0] + base_name_face_det = base_name + tokens = base_name.split("_") + if len(tokens) >= 5: + base_name_face_det = "_".join(tokens[:-4]) + cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] + + caption = None + for cap_path in cap_paths: + if os.path.isfile(cap_path): + with open(cap_path, "rt", encoding="utf-8") as f: + try: + lines = f.readlines() + except UnicodeDecodeError as e: + logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") + raise e + assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" + if enable_wildcard: + caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結 + else: + caption = lines[0].strip() + break + return caption + + def load_dreambooth_dir(subset: DreamBoothSubset): + if not os.path.isdir(subset.image_dir): + logger.warning(f"not directory: {subset.image_dir}") + return [], [], [] + + info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE) + use_cached_info_for_subset = subset.cache_info + if use_cached_info_for_subset: + logger.info( + f"using cached image info for this subset / このサブセットで、キャッシュされた画像情報を使います: {info_cache_file}" + ) + if not os.path.isfile(info_cache_file): + logger.warning( + f"image info file not found. You can ignore this warning if this is the first time to use this subset" + + " / キャッシュファイルが見つかりませんでした。初回実行時はこの警告を無視してください: {metadata_file}" + ) + use_cached_info_for_subset = False + + if use_cached_info_for_subset: + # json: {`img_path`:{"caption": "caption...", "resolution": [width, height]}, ...} + with open(info_cache_file, "r", encoding="utf-8") as f: + metas = json.load(f) + img_paths = list(metas.keys()) + sizes = [meta["resolution"] for meta in metas.values()] + + # we may need to check image size and existence of image files, but it takes time, so user should check it before training + else: + img_paths = glob_images(subset.image_dir, "*") + sizes = [None] * len(img_paths) + + # new caching: get image size from cache files + strategy = LatentsCachingStrategy.get_strategy() + if strategy is not None: + logger.info("get image size from name of cache files") + + # make image path to npz path mapping + npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix)) + npz_paths.sort( + key=lambda item: item.rsplit("_", maxsplit=2)[0] + ) # sort by name excluding resolution and cache_suffix + npz_path_index = 0 + + size_set_count = 0 + for i, img_path in enumerate(tqdm(img_paths)): + l = len(os.path.splitext(img_path)[0]) # remove extension + found = False + while npz_path_index < len(npz_paths): # until found or end of npz_paths + # npz_paths are sorted, so if npz_path > img_path, img_path is not found + if npz_paths[npz_path_index][:l] > img_path[:l]: + break + if npz_paths[npz_path_index][:l] == img_path[:l]: # found + found = True + break + npz_path_index += 1 # next npz_path + + if found: + w, h = strategy.get_image_size_from_disk_cache_path(img_path, npz_paths[npz_path_index]) + else: + w, h = None, None + + if w is not None and h is not None: + sizes[i] = [w, h] + size_set_count += 1 + logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") + + if use_cached_info_for_subset: + captions = [meta["caption"] for meta in metas.values()] + missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""] + else: + # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う + captions = [] + missing_captions = [] + for img_path in tqdm(img_paths, desc="read caption"): + cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard) + if cap_for_img is None and subset.class_tokens is None: + logger.warning( + f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}" + ) + captions.append("") + missing_captions.append(img_path) + else: + if cap_for_img is None: + final_caption = subset.class_tokens + missing_captions.append(img_path) + else: + final_caption = cap_for_img + if subset.class_tokens: + final_caption = f"{subset.class_tokens} {final_caption}" # Prepend class token to the caption + + captions.append(final_caption) + + self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録 + + logger.info(f"Found captions for {len(captions) - len(missing_captions)} images.") + + if missing_captions: + number_of_missing_captions = len(missing_captions) + number_of_missing_captions_to_show = 5 + remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show + + logger.warning( + f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。" + ) + for i, missing_caption in enumerate(missing_captions): + if i >= number_of_missing_captions_to_show: + logger.warning(missing_caption + f"... and {remaining_missing_captions} more") + break + logger.warning(missing_caption) + + if not use_cached_info_for_subset and subset.cache_info: + logger.info(f"cache image info for / 画像情報をキャッシュします : {info_cache_file}") + sizes = [self.get_image_size(img_path) for img_path in tqdm(img_paths, desc="get image size")] + matas = {} + for img_path, caption, size in zip(img_paths, captions, sizes): + matas[img_path] = {"caption": caption, "resolution": list(size)} + with open(info_cache_file, "w", encoding="utf-8") as f: + json.dump(matas, f, ensure_ascii=False, indent=2) + logger.info(f"cache image info done for / 画像情報を出力しました : {info_cache_file}") + + # if sizes are not set, image size will be read in make_buckets + return img_paths, captions, sizes + + logger.info("prepare images.") + num_train_images = 0 + num_reg_images = 0 + reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = [] + for subset in subsets: + if subset.num_repeats < 1: + logger.warning( + f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" + ) + continue + + if subset in self.subsets: + logger.warning( + f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" + ) + continue + + img_paths, captions, sizes = load_dreambooth_dir(subset) + if len(img_paths) < 1: + logger.warning( + f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します" + ) + continue + + if subset.is_reg: + num_reg_images += subset.num_repeats * len(img_paths) + else: + num_train_images += subset.num_repeats * len(img_paths) + + for img_path, caption, size in zip(img_paths, captions, sizes): + info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + if size is not None: + info.image_size = size + if subset.is_reg: + reg_infos.append((info, subset)) + else: + self.register_image(info, subset) + + subset.img_count = len(img_paths) + self.subsets.append(subset) + + logger.info(f"{num_train_images} train images with repeating.") + self.num_train_images = num_train_images + + logger.info(f"{num_reg_images} reg images.") + if num_train_images < num_reg_images: + logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") + + if num_reg_images == 0: + logger.warning("no regularization images / 正則化画像が見つかりませんでした") + else: + # num_repeatsを計算する:どうせ大した数ではないのでループで処理する + n = 0 + first_loop = True + while n < num_train_images: + for info, subset in reg_infos: + if first_loop: + self.register_image(info, subset) + n += info.num_repeats + else: + info.num_repeats += 1 # rewrite registered info + n += 1 + if n >= num_train_images: + break + first_loop = False + + self.num_reg_images = num_reg_images + + +class FineTuningDataset(BaseDataset): + def __init__( + self, + subsets: Sequence[FineTuningSubset], + batch_size: int, + resolution, + network_multiplier: float, + enable_bucket: bool, + min_bucket_reso: int, + max_bucket_reso: int, + bucket_reso_steps: int, + bucket_no_upscale: bool, + debug_dataset: bool, + ) -> None: + super().__init__(resolution, network_multiplier, debug_dataset) + + self.batch_size = batch_size + + self.num_train_images = 0 + self.num_reg_images = 0 + + for subset in subsets: + if subset.num_repeats < 1: + logger.warning( + f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}" + ) + continue + + if subset in self.subsets: + logger.warning( + f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します" + ) + continue + + # メタデータを読み込む + if os.path.exists(subset.metadata_file): + logger.info(f"loading existing metadata: {subset.metadata_file}") + with open(subset.metadata_file, "rt", encoding="utf-8") as f: + metadata = json.load(f) + else: + raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}") + + if len(metadata) < 1: + logger.warning( + f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します" + ) + continue + + tags_list = [] + for image_key, img_md in metadata.items(): + # path情報を作る + abs_path = None + + # まず画像を優先して探す + if os.path.exists(image_key): + abs_path = image_key + else: + # わりといい加減だがいい方法が思いつかん + paths = glob_images(subset.image_dir, image_key) + if len(paths) > 0: + abs_path = paths[0] + + # なければnpzを探す + if abs_path is None: + if os.path.exists(os.path.splitext(image_key)[0] + ".npz"): + abs_path = os.path.splitext(image_key)[0] + ".npz" + else: + npz_path = os.path.join(subset.image_dir, image_key + ".npz") + if os.path.exists(npz_path): + abs_path = npz_path + + assert abs_path is not None, f"no image / 画像がありません: {image_key}" + + caption = img_md.get("caption") + tags = img_md.get("tags") + if caption is None: + caption = tags # could be multiline + tags = None + + if subset.enable_wildcard: + # tags must be single line + if tags is not None: + tags = tags.replace("\n", subset.caption_separator) + + # add tags to each line of caption + if caption is not None and tags is not None: + caption = "\n".join( + [f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""] + ) + else: + # use as is + if tags is not None and len(tags) > 0: + caption = caption + subset.caption_separator + tags + tags_list.append(tags) + + if caption is None: + caption = "" + + image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path) + image_info.image_size = img_md.get("train_resolution") + + if not subset.color_aug and not subset.random_crop: + # if npz exists, use them + image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key) + + self.register_image(image_info, subset) + + self.num_train_images += len(metadata) * subset.num_repeats + + # TODO do not record tag freq when no tag + self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list) + subset.img_count = len(metadata) + self.subsets.append(subset) + + # check existence of all npz files + use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets]) + if use_npz_latents: + flip_aug_in_subset = False + npz_any = False + npz_all = True + + for image_info in self.image_data.values(): + subset = self.image_to_subset[image_info.image_key] + + has_npz = image_info.latents_npz is not None + npz_any = npz_any or has_npz + + if subset.flip_aug: + has_npz = has_npz and image_info.latents_npz_flipped is not None + flip_aug_in_subset = True + npz_all = npz_all and has_npz + + if npz_any and not npz_all: + break + + if not npz_any: + use_npz_latents = False + logger.warning(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します") + elif not npz_all: + use_npz_latents = False + logger.warning( + f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します" + ) + if flip_aug_in_subset: + logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません") + # else: + # logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません") + + # check min/max bucket size + sizes = set() + resos = set() + for image_info in self.image_data.values(): + if image_info.image_size is None: + sizes = None # not calculated + break + sizes.add(image_info.image_size[0]) + sizes.add(image_info.image_size[1]) + resos.add(tuple(image_info.image_size)) + + if sizes is None: + if use_npz_latents: + use_npz_latents = False + logger.warning( + f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します" + ) + + assert ( + resolution is not None + ), "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください" + + self.enable_bucket = enable_bucket + if self.enable_bucket: + min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps( + resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps + ) + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso + self.bucket_reso_steps = bucket_reso_steps + self.bucket_no_upscale = bucket_no_upscale + else: + if not enable_bucket: + logger.info("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") + logger.info("using bucket info in metadata / メタデータ内のbucket情報を使います") + self.enable_bucket = True + + assert ( + not bucket_no_upscale + ), "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません" + + # bucket情報を初期化しておく、make_bucketsで再作成しない + self.bucket_manager = BucketManager(False, None, None, None, None) + self.bucket_manager.set_predefined_resos(resos) + + # npz情報をきれいにしておく + if not use_npz_latents: + for image_info in self.image_data.values(): + image_info.latents_npz = image_info.latents_npz_flipped = None + + def image_key_to_npz_file(self, subset: FineTuningSubset, image_key): + base_name = os.path.splitext(image_key)[0] + npz_file_norm = base_name + ".npz" + + if os.path.exists(npz_file_norm): + # image_key is full path + npz_file_flip = base_name + "_flip.npz" + if not os.path.exists(npz_file_flip): + npz_file_flip = None + return npz_file_norm, npz_file_flip + + # if not full path, check image_dir. if image_dir is None, return None + if subset.image_dir is None: + return None, None + + # image_key is relative path + npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz") + npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz") + + if not os.path.exists(npz_file_norm): + npz_file_norm = None + npz_file_flip = None + elif not os.path.exists(npz_file_flip): + npz_file_flip = None + + return npz_file_norm, npz_file_flip + + +class ControlNetDataset(BaseDataset): + def __init__( + self, + subsets: Sequence[ControlNetSubset], + batch_size: int, + resolution, + network_multiplier: float, + enable_bucket: bool, + min_bucket_reso: int, + max_bucket_reso: int, + bucket_reso_steps: int, + bucket_no_upscale: bool, + debug_dataset: float, + ) -> None: + super().__init__(resolution, network_multiplier, debug_dataset) + + db_subsets = [] + for subset in subsets: + assert ( + not subset.random_crop + ), "random_crop is not supported in ControlNetDataset / random_cropはControlNetDatasetではサポートされていません" + db_subset = DreamBoothSubset( + subset.image_dir, + False, + None, + subset.caption_extension, + subset.cache_info, + False, + subset.num_repeats, + subset.shuffle_caption, + subset.caption_separator, + subset.keep_tokens, + subset.keep_tokens_separator, + subset.secondary_separator, + subset.enable_wildcard, + subset.color_aug, + subset.flip_aug, + subset.face_crop_aug_range, + subset.random_crop, + subset.caption_dropout_rate, + subset.caption_dropout_every_n_epochs, + subset.caption_tag_dropout_rate, + subset.caption_prefix, + subset.caption_suffix, + subset.token_warmup_min, + subset.token_warmup_step, + ) + db_subsets.append(db_subset) + + self.dreambooth_dataset_delegate = DreamBoothDataset( + db_subsets, + batch_size, + resolution, + network_multiplier, + enable_bucket, + min_bucket_reso, + max_bucket_reso, + bucket_reso_steps, + bucket_no_upscale, + 1.0, + debug_dataset, + ) + + # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) + self.image_data = self.dreambooth_dataset_delegate.image_data + self.batch_size = batch_size + self.num_train_images = self.dreambooth_dataset_delegate.num_train_images + self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images + + # assert all conditioning data exists + missing_imgs = [] + cond_imgs_with_pair = set() + for image_key, info in self.dreambooth_dataset_delegate.image_data.items(): + db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key] + subset = None + for s in subsets: + if s.image_dir == db_subset.image_dir: + subset = s + break + assert subset is not None, "internal error: subset not found" + + if not os.path.isdir(subset.conditioning_data_dir): + logger.warning(f"not directory: {subset.conditioning_data_dir}") + continue + + img_basename = os.path.splitext(os.path.basename(info.absolute_path))[0] + ctrl_img_path = glob_images(subset.conditioning_data_dir, img_basename) + if len(ctrl_img_path) < 1: + missing_imgs.append(img_basename) + continue + ctrl_img_path = ctrl_img_path[0] + ctrl_img_path = os.path.abspath(ctrl_img_path) # normalize path + + info.cond_img_path = ctrl_img_path + cond_imgs_with_pair.add(os.path.splitext(ctrl_img_path)[0]) # remove extension because Windows is case insensitive + + extra_imgs = [] + for subset in subsets: + conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*") + conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path + extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair]) + + assert ( + len(missing_imgs) == 0 + ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}" + assert ( + len(extra_imgs) == 0 + ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}" + + self.conditioning_image_transforms = IMAGE_TRANSFORMS + + def set_current_strategies(self): + return self.dreambooth_dataset_delegate.set_current_strategies() + + def make_buckets(self): + self.dreambooth_dataset_delegate.make_buckets() + self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager + self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices + + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + + def new_cache_latents(self, model: Any, accelerator: Accelerator): + return self.dreambooth_dataset_delegate.new_cache_latents(model, accelerator) + + def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): + return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process) + + def __len__(self): + return self.dreambooth_dataset_delegate.__len__() + + def __getitem__(self, index): + example = self.dreambooth_dataset_delegate[index] + + bucket = self.dreambooth_dataset_delegate.bucket_manager.buckets[ + self.dreambooth_dataset_delegate.buckets_indices[index].bucket_index + ] + bucket_batch_size = self.dreambooth_dataset_delegate.buckets_indices[index].bucket_batch_size + image_index = self.dreambooth_dataset_delegate.buckets_indices[index].batch_index * bucket_batch_size + + conditioning_images = [] + + for i, image_key in enumerate(bucket[image_index : image_index + bucket_batch_size]): + image_info = self.dreambooth_dataset_delegate.image_data[image_key] + + target_size_hw = example["target_sizes_hw"][i] + original_size_hw = example["original_sizes_hw"][i] + crop_top_left = example["crop_top_lefts"][i] + flipped = example["flippeds"][i] + cond_img = load_image(image_info.cond_img_path) + + if self.dreambooth_dataset_delegate.enable_bucket: + assert ( + cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] + ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" + cond_img = cv2.resize( + cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA + ) # INTER_AREAでやりたいのでcv2でリサイズ + + # TODO support random crop + # 現在サポートしているcropはrandomではなく中央のみ + h, w = target_size_hw + ct = (cond_img.shape[0] - h) // 2 + cl = (cond_img.shape[1] - w) // 2 + cond_img = cond_img[ct : ct + h, cl : cl + w] + else: + # assert ( + # cond_img.shape[0] == self.height and cond_img.shape[1] == self.width + # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + # resize to target + if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: + cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0]))) + + if flipped: + cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride + + cond_img = self.conditioning_image_transforms(cond_img) + conditioning_images.append(cond_img) + + example["conditioning_images"] = torch.stack(conditioning_images).to(memory_format=torch.contiguous_format).float() + + return example + + +# behave as Dataset mock +class DatasetGroup(torch.utils.data.ConcatDataset): + def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]): + self.datasets: List[Union[DreamBoothDataset, FineTuningDataset]] + + super().__init__(datasets) + + self.image_data = {} + self.num_train_images = 0 + self.num_reg_images = 0 + + # simply concat together + # TODO: handling image_data key duplication among dataset + # In practical, this is not the big issue because image_data is accessed from outside of dataset only for debug_dataset. + for dataset in datasets: + self.image_data.update(dataset.image_data) + self.num_train_images += dataset.num_train_images + self.num_reg_images += dataset.num_reg_images + + def add_replacement(self, str_from, str_to): + for dataset in self.datasets: + dataset.add_replacement(str_from, str_to) + + # def make_buckets(self): + # for dataset in self.datasets: + # dataset.make_buckets() + + def set_text_encoder_output_caching_strategy(self, strategy: TextEncoderOutputsCachingStrategy): + """ + DataLoader is run in multiple processes, so we need to set the strategy manually. + """ + for dataset in self.datasets: + dataset.set_text_encoder_output_caching_strategy(strategy) + + def enable_XTI(self, *args, **kwargs): + for dataset in self.datasets: + dataset.enable_XTI(*args, **kwargs) + + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) + + def new_cache_latents(self, model: Any, accelerator: Accelerator): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.new_cache_latents(model, accelerator) + accelerator.wait_for_everyone() + + def cache_text_encoder_outputs( + self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True + ): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process) + + def cache_text_encoder_outputs_sd3( + self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None + ): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.cache_text_encoder_outputs_sd3( + tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size + ) + + def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator): + for i, dataset in enumerate(self.datasets): + logger.info(f"[Dataset {i}]") + dataset.new_cache_text_encoder_outputs(models, accelerator) + accelerator.wait_for_everyone() + + def set_caching_mode(self, caching_mode): + for dataset in self.datasets: + dataset.set_caching_mode(caching_mode) + + def verify_bucket_reso_steps(self, min_steps: int): + for dataset in self.datasets: + dataset.verify_bucket_reso_steps(min_steps) + + def get_resolutions(self) -> List[Tuple[int, int]]: + return [(dataset.width, dataset.height) for dataset in self.datasets] + + def is_latent_cacheable(self) -> bool: + return all([dataset.is_latent_cacheable() for dataset in self.datasets]) + + def is_text_encoder_output_cacheable(self) -> bool: + return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets]) + + def set_current_strategies(self): + for dataset in self.datasets: + dataset.set_current_strategies() + + def set_current_epoch(self, epoch): + for dataset in self.datasets: + dataset.set_current_epoch(epoch) + + def set_current_step(self, step): + for dataset in self.datasets: + dataset.set_current_step(step) + + def set_max_train_steps(self, max_train_steps): + for dataset in self.datasets: + dataset.set_max_train_steps(max_train_steps) + + def disable_token_padding(self): + for dataset in self.datasets: + dataset.disable_token_padding() + + +def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool): + expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意 + + if not os.path.exists(npz_path): + return False + + try: + npz = np.load(npz_path) + if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? + return False + if npz["latents"].shape[1:3] != expected_latents_size: + return False + + if flip_aug: + if "latents_flipped" not in npz: + return False + if npz["latents_flipped"].shape[1:3] != expected_latents_size: + return False + + if alpha_mask: + if "alpha_mask" not in npz: + return False + if (npz["alpha_mask"].shape[1], npz["alpha_mask"].shape[0]) != reso: # HxW => WxH != reso + return False + else: + if "alpha_mask" in npz: + return False + except Exception as e: + logger.error(f"Error loading file: {npz_path}") + raise e + + return True + + +# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) +# TODO update to use CachingStrategy +# def load_latents_from_disk( +# npz_path, +# ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: +# npz = np.load(npz_path) +# if "latents" not in npz: +# raise ValueError(f"error: npz is old format. please re-generate {npz_path}") + +# latents = npz["latents"] +# original_size = npz["original_size"].tolist() +# crop_ltrb = npz["crop_ltrb"].tolist() +# flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None +# alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None +# return latents, original_size, crop_ltrb, flipped_latents, alpha_mask + + +# def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None): +# kwargs = {} +# if flipped_latents_tensor is not None: +# kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() +# if alpha_mask is not None: +# kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() +# np.savez( +# npz_path, +# latents=latents_tensor.float().cpu().numpy(), +# original_size=np.array(original_size), +# crop_ltrb=np.array(crop_ltrb), +# **kwargs, +# ) + + +def debug_dataset(train_dataset, show_input_ids=False): + logger.info(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") + logger.info( + "`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します" + ) + + epoch = 1 + while True: + logger.info(f"") + logger.info(f"epoch: {epoch}") + + steps = (epoch - 1) * len(train_dataset) + 1 + indices = list(range(len(train_dataset))) + random.shuffle(indices) + + k = 0 + for i, idx in enumerate(indices): + train_dataset.set_current_epoch(epoch) + train_dataset.set_current_step(steps) + logger.info(f"steps: {steps} ({i + 1}/{len(train_dataset)})") + + example = train_dataset[idx] + if example["latents"] is not None: + logger.info(f"sample has latents from npz file: {example['latents'].size()}") + for j, (ik, cap, lw, orgsz, crptl, trgsz, flpdz) in enumerate( + zip( + example["image_keys"], + example["captions"], + example["loss_weights"], + # example["input_ids"], + example["original_sizes_hw"], + example["crop_top_lefts"], + example["target_sizes_hw"], + example["flippeds"], + ) + ): + logger.info( + f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}' + ) + if "network_multipliers" in example: + logger.info(f"network multiplier: {example['network_multipliers'][j]}") + if "custom_attributes" in example: + logger.info(f"custom attributes: {example['custom_attributes'][j]}") + + # if show_input_ids: + # logger.info(f"input ids: {iid}") + # if "input_ids2" in example: + # logger.info(f"input ids2: {example['input_ids2'][j]}") + if example["images"] is not None: + im = example["images"][j] + logger.info(f"image size: {im.size()}") + im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) + im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c + im = im[:, :, ::-1] # RGB -> BGR (OpenCV) + + if "conditioning_images" in example: + cond_img = example["conditioning_images"][j] + logger.info(f"conditioning image size: {cond_img.size()}") + cond_img = ((cond_img.numpy() + 1.0) * 127.5).astype(np.uint8) + cond_img = np.transpose(cond_img, (1, 2, 0)) + cond_img = cond_img[:, :, ::-1] + if os.name == "nt": + cv2.imshow("cond_img", cond_img) + + if "alpha_masks" in example and example["alpha_masks"] is not None: + alpha_mask = example["alpha_masks"][j] + logger.info(f"alpha mask size: {alpha_mask.size()}") + alpha_mask = (alpha_mask.numpy() * 255.0).astype(np.uint8) + if os.name == "nt": + cv2.imshow("alpha_mask", alpha_mask) + + if os.name == "nt": # only windows + cv2.imshow("img", im) + k = cv2.waitKey() + cv2.destroyAllWindows() + if k == 27 or k == ord("s") or k == ord("e"): + break + steps += 1 + + if k == ord("e"): + break + if k == 27 or (example["images"] is None and i >= 8): + k = 27 + break + if k == 27: + break + + epoch += 1 + + +def glob_images(directory, base="*"): + img_paths = [] + for ext in IMAGE_EXTENSIONS: + if base == "*": + img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) + else: + img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) + img_paths = list(set(img_paths)) # 重複を排除 + img_paths.sort() + return img_paths + + +def glob_images_pathlib(dir_path, recursive): + image_paths = [] + if recursive: + for ext in IMAGE_EXTENSIONS: + image_paths += list(dir_path.rglob("*" + ext)) + else: + for ext in IMAGE_EXTENSIONS: + image_paths += list(dir_path.glob("*" + ext)) + image_paths = list(set(image_paths)) # 重複を排除 + image_paths.sort() + return image_paths + + +class MinimalDataset(BaseDataset): + def __init__(self, resolution, network_multiplier, debug_dataset=False): + super().__init__(resolution, network_multiplier, debug_dataset) + + self.num_train_images = 0 # update in subclass + self.num_reg_images = 0 # update in subclass + self.datasets = [self] + self.batch_size = 1 # update in subclass + + self.subsets = [self] + self.num_repeats = 1 # update in subclass if needed + self.img_count = 1 # update in subclass if needed + self.bucket_info = {} + self.is_reg = False + self.image_dir = "dummy" # for metadata + + def verify_bucket_reso_steps(self, min_steps: int): + pass + + def is_latent_cacheable(self) -> bool: + return False + + def __len__(self): + raise NotImplementedError + + # override to avoid shuffling buckets + def set_current_epoch(self, epoch): + self.current_epoch = epoch + + def __getitem__(self, idx): + r""" + The subclass may have image_data for debug_dataset, which is a dict of ImageInfo objects. + + Returns: example like this: + + for i in range(batch_size): + image_key = ... # whatever hashable + image_keys.append(image_key) + + image = ... # PIL Image + img_tensor = self.image_transforms(img) + images.append(img_tensor) + + caption = ... # str + input_ids = self.get_input_ids(caption) + input_ids_list.append(input_ids) + + captions.append(caption) + + images = torch.stack(images, dim=0) + input_ids_list = torch.stack(input_ids_list, dim=0) + example = { + "images": images, + "input_ids": input_ids_list, + "captions": captions, # for debug_dataset + "latents": None, + "image_keys": image_keys, # for debug_dataset + "loss_weights": torch.ones(batch_size, dtype=torch.float32), + } + return example + """ + raise NotImplementedError + + +def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset: + module = ".".join(args.dataset_class.split(".")[:-1]) + dataset_class = args.dataset_class.split(".")[-1] + module = importlib.import_module(module) + dataset_class = getattr(module, dataset_class) + train_dataset_group: MinimalDataset = dataset_class(tokenizer, args.max_token_length, args.resolution, args.debug_dataset) + return train_dataset_group + + +def load_image(image_path, alpha=False): + try: + with Image.open(image_path) as image: + if alpha: + if not image.mode == "RGBA": + image = image.convert("RGBA") + else: + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) + return img + except (IOError, OSError) as e: + logger.error(f"Error loading file: {image_path}") + raise e + + +# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom) +def trim_and_resize_if_required( + random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int] +) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]: + image_height, image_width = image.shape[0:2] + original_size = (image_width, image_height) # size before resize + + if image_width != resized_size[0] or image_height != resized_size[1]: + # リサイズする + if image_width > resized_size[0] and image_height > resized_size[1]: + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + else: + image = pil_resize(image, resized_size) + + image_height, image_width = image.shape[0:2] + + if image_width > reso[0]: + trim_size = image_width - reso[0] + p = trim_size // 2 if not random_crop else random.randint(0, trim_size) + # logger.info(f"w {trim_size} {p}") + image = image[:, p : p + reso[0]] + if image_height > reso[1]: + trim_size = image_height - reso[1] + p = trim_size // 2 if not random_crop else random.randint(0, trim_size) + # logger.info(f"h {trim_size} {p}) + image = image[p : p + reso[1]] + + # random cropの場合のcropされた値をどうcrop left/topに反映するべきか全くアイデアがない + # I have no idea how to reflect the cropped value in crop left/top in the case of random crop + + crop_ltrb = BucketManager.get_crop_ltrb(reso, original_size) + + assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}" + return image, original_size, crop_ltrb + + +# for new_cache_latents +def load_images_and_masks_for_caching( + image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool +) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: + r""" + requires image_infos to have: [absolute_path or image], bucket_reso, resized_size + + returns: image_tensor, alpha_masks, original_sizes, crop_ltrbs + + image_tensor: torch.Tensor = torch.Size([B, 3, H, W]), ...], normalized to [-1, 1] + alpha_masks: List[np.ndarray] = [np.ndarray([H, W]), ...], normalized to [0, 1] + original_sizes: List[Tuple[int, int]] = [(W, H), ...] + crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...] + """ + images: List[torch.Tensor] = [] + alpha_masks: List[np.ndarray] = [] + original_sizes: List[Tuple[int, int]] = [] + crop_ltrbs: List[Tuple[int, int, int, int]] = [] + for info in image_infos: + image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) + # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 + image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + + original_sizes.append(original_size) + crop_ltrbs.append(crop_ltrb) + + if use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] + else: + alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] + else: + alpha_mask = None + alpha_masks.append(alpha_mask) + + image = image[:, :, :3] # remove alpha channel if exists + image = IMAGE_TRANSFORMS(image) + images.append(image) + + img_tensor = torch.stack(images, dim=0) + return img_tensor, alpha_masks, original_sizes, crop_ltrbs + + +def cache_batch_latents( + vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool +) -> None: + r""" + requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz + optionally requires image_infos to have: image + if cache_to_disk is True, set info.latents_npz + flipped latents is also saved if flip_aug is True + if cache_to_disk is False, set info.latents + latents_flipped is also set if flip_aug is True + latents_original_size and latents_crop_ltrb are also set + """ + images = [] + alpha_masks: List[np.ndarray] = [] + for info in image_infos: + image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) + # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 + image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + + info.latents_original_size = original_size + info.latents_crop_ltrb = crop_ltrb + + if use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] + else: + alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] + else: + alpha_mask = None + alpha_masks.append(alpha_mask) + + image = image[:, :, :3] # remove alpha channel if exists + image = IMAGE_TRANSFORMS(image) + images.append(image) + + img_tensors = torch.stack(images, dim=0) + img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) + + with torch.no_grad(): + latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + + if flip_aug: + img_tensors = torch.flip(img_tensors, dims=[3]) + with torch.no_grad(): + flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + else: + flipped_latents = [None] * len(latents) + + for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks): + # check NaN + if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()): + raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") + + if cache_to_disk: + # save_latents_to_disk( + # info.latents_npz, + # latent, + # info.latents_original_size, + # info.latents_crop_ltrb, + # flipped_latent, + # alpha_mask, + # ) + pass + else: + info.latents = latent + if flip_aug: + info.latents_flipped = flipped_latent + info.alpha_mask = alpha_mask + + if not HIGH_VRAM: + clean_memory_on_device(vae.device) + + +def cache_batch_text_encoder_outputs( + image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype +): + input_ids1 = input_ids1.to(text_encoders[0].device) + input_ids2 = input_ids2.to(text_encoders[1].device) + + with torch.no_grad(): + b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl( + max_token_length, + input_ids1, + input_ids2, + tokenizers[0], + tokenizers[1], + text_encoders[0], + text_encoders[1], + dtype, + ) + + # ここでcpuに移動しておかないと、上書きされてしまう + b_hidden_state1 = b_hidden_state1.detach().to("cpu") # b,n*75+2,768 + b_hidden_state2 = b_hidden_state2.detach().to("cpu") # b,n*75+2,1280 + b_pool2 = b_pool2.detach().to("cpu") # b,1280 + + for info, hidden_state1, hidden_state2, pool2 in zip(image_infos, b_hidden_state1, b_hidden_state2, b_pool2): + if cache_to_disk: + save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, hidden_state1, hidden_state2, pool2) + else: + info.text_encoder_outputs1 = hidden_state1 + info.text_encoder_outputs2 = hidden_state2 + info.text_encoder_pool2 = pool2 + + +def cache_batch_text_encoder_outputs_sd3( + image_infos, tokenizer, text_encoders, max_token_length, cache_to_disk, input_ids, output_dtype +): + # make input_ids for each text encoder + l_tokens, g_tokens, t5_tokens = input_ids + + clip_l, clip_g, t5xxl = text_encoders + with torch.no_grad(): + b_lg_out, b_t5_out, b_pool = sd3_utils.get_cond_from_tokens( + l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, "cpu", output_dtype + ) + b_lg_out = b_lg_out.detach() + b_t5_out = b_t5_out.detach() + b_pool = b_pool.detach() + + for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool): + # debug: NaN check + if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any(): + raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}") + + if cache_to_disk: + save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool) + else: + info.text_encoder_outputs1 = lg_out + info.text_encoder_outputs2 = t5_out + info.text_encoder_pool2 = pool + + +def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2): + np.savez( + npz_path, + hidden_state1=hidden_state1.cpu().float().numpy(), + hidden_state2=hidden_state2.cpu().float().numpy(), + pool2=pool2.cpu().float().numpy(), + ) + + +def load_text_encoder_outputs_from_disk(npz_path): + with np.load(npz_path) as f: + hidden_state1 = torch.from_numpy(f["hidden_state1"]) + hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f else None + pool2 = torch.from_numpy(f["pool2"]) if "pool2" in f else None + return hidden_state1, hidden_state2, pool2 + + +# endregion + +# region モジュール入れ替え部 +""" +高速化のためのモジュール入れ替え +""" + +# FlashAttentionを使うCrossAttention +# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py +# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE + +# constants + +EPSILON = 1e-6 + +# helper functions + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def model_hash(filename): + """Old model hash used by stable-diffusion-webui""" + try: + with open(filename, "rb") as file: + m = hashlib.sha256() + + file.seek(0x100000) + m.update(file.read(0x10000)) + return m.hexdigest()[0:8] + except FileNotFoundError: + return "NOFILE" + except IsADirectoryError: # Linux? + return "IsADirectory" + except PermissionError: # Windows + return "IsADirectory" + + +def calculate_sha256(filename): + """New model hash used by stable-diffusion-webui""" + try: + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + except FileNotFoundError: + return "NOFILE" + except IsADirectoryError: # Linux? + return "IsADirectory" + except PermissionError: # Windows + return "IsADirectory" + + +def precalculate_safetensors_hashes(tensors, metadata): + """Precalculate the model hashes needed by sd-webui-additional-networks to + save time on indexing the model later.""" + + # Because writing user metadata to the file can change the result of + # sd_models.model_hash(), only retain the training metadata for purposes of + # calculating the hash, as they are meant to be immutable + metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")} + + bytes = safetensors.torch.save(tensors, metadata) + b = BytesIO(bytes) + + model_hash = addnet_hash_safetensors(b) + legacy_hash = addnet_hash_legacy(b) + return model_hash, legacy_hash + + +def addnet_hash_legacy(b): + """Old model hash used by sd-webui-additional-networks for .safetensors format files""" + m = hashlib.sha256() + + b.seek(0x100000) + m.update(b.read(0x10000)) + return m.hexdigest()[0:8] + + +def addnet_hash_safetensors(b): + """New model hash used by sd-webui-additional-networks for .safetensors format files""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + +def get_git_revision_hash() -> str: + try: + return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=os.path.dirname(__file__)).decode("ascii").strip() + except: + return "(unknown)" + + +# def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): +# replace_attentions_for_hypernetwork() +# # unet is not used currently, but it is here for future use +# unet.enable_xformers_memory_efficient_attention() +# return +# if mem_eff_attn: +# unet.set_attn_processor(FlashAttnProcessor()) +# elif xformers: +# unet.enable_xformers_memory_efficient_attention() + + +# def replace_unet_cross_attn_to_xformers(): +# logger.info("CrossAttention.forward has been replaced to enable xformers.") +# try: +# import xformers.ops +# except ImportError: +# raise ImportError("No xformers / xformersがインストールされていないようです") + +# def forward_xformers(self, x, context=None, mask=None): +# h = self.heads +# q_in = self.to_q(x) + +# context = default(context, x) +# context = context.to(x.dtype) + +# if hasattr(self, "hypernetwork") and self.hypernetwork is not None: +# context_k, context_v = self.hypernetwork.forward(x, context) +# context_k = context_k.to(x.dtype) +# context_v = context_v.to(x.dtype) +# else: +# context_k = context +# context_v = context + +# k_in = self.to_k(context_k) +# v_in = self.to_v(context_v) + +# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) +# del q_in, k_in, v_in + +# q = q.contiguous() +# k = k.contiguous() +# v = v.contiguous() +# out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + +# out = rearrange(out, "b n h d -> b n (h d)", h=h) + +# # diffusers 0.7.0~ +# out = self.to_out[0](out) +# out = self.to_out[1](out) +# return out + + +# diffusers.models.attention.CrossAttention.forward = forward_xformers +def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + logger.info("Enable memory efficient attention for U-Net") + unet.set_use_memory_efficient_attention(False, True) + elif xformers: + logger.info("Enable xformers for U-Net") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") + + unet.set_use_memory_efficient_attention(True, False) + elif sdpa: + logger.info("Enable SDPA for U-Net") + unet.set_use_sdpa(True) + + +""" +def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers): + # vae is not used currently, but it is here for future use + if mem_eff_attn: + replace_vae_attn_to_memory_efficient() + elif xformers: + # とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ + logger.info("Use Diffusers xformers for VAE") + vae.encoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) + vae.decoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) + + +def replace_vae_attn_to_memory_efficient(): + logger.info("AttentionBlock.forward has been replaced to FlashAttention (not xformers)") + flash_func = FlashAttentionFunction + + def forward_flash_attn(self, hidden_states): + logger.info("forward_flash_attn") + q_bucket_size = 512 + k_bucket_size = 1024 + + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj) + ) + + out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + diffusers.models.attention.AttentionBlock.forward = forward_flash_attn +""" + + +# endregion + + +# region arguments + + +def load_metadata_from_safetensors(safetensors_file: str) -> dict: + """r + This method locks the file. see https://github.com/huggingface/safetensors/issues/164 + If the file isn't .safetensors or doesn't have metadata, return empty dict. + """ + if os.path.splitext(safetensors_file)[1] != ".safetensors": + return {} + + with safetensors.safe_open(safetensors_file, framework="pt", device="cpu") as f: + metadata = f.metadata() + if metadata is None: + metadata = {} + return metadata + + +# this metadata is referred from train_network and various scripts, so we wrote here +SS_METADATA_KEY_V2 = "ss_v2" +SS_METADATA_KEY_BASE_MODEL_VERSION = "ss_base_model_version" +SS_METADATA_KEY_NETWORK_MODULE = "ss_network_module" +SS_METADATA_KEY_NETWORK_DIM = "ss_network_dim" +SS_METADATA_KEY_NETWORK_ALPHA = "ss_network_alpha" +SS_METADATA_KEY_NETWORK_ARGS = "ss_network_args" + +SS_METADATA_MINIMUM_KEYS = [ + SS_METADATA_KEY_V2, + SS_METADATA_KEY_BASE_MODEL_VERSION, + SS_METADATA_KEY_NETWORK_MODULE, + SS_METADATA_KEY_NETWORK_DIM, + SS_METADATA_KEY_NETWORK_ALPHA, + SS_METADATA_KEY_NETWORK_ARGS, +] + + +def build_minimum_network_metadata( + v2: Optional[str], + base_model: Optional[str], + network_module: str, + network_dim: str, + network_alpha: str, + network_args: Optional[dict], +): + # old LoRA doesn't have base_model + metadata = { + SS_METADATA_KEY_NETWORK_MODULE: network_module, + SS_METADATA_KEY_NETWORK_DIM: network_dim, + SS_METADATA_KEY_NETWORK_ALPHA: network_alpha, + } + if v2 is not None: + metadata[SS_METADATA_KEY_V2] = v2 + if base_model is not None: + metadata[SS_METADATA_KEY_BASE_MODEL_VERSION] = base_model + if network_args is not None: + metadata[SS_METADATA_KEY_NETWORK_ARGS] = json.dumps(network_args) + return metadata + + +def get_sai_model_spec( + state_dict: dict, + args: argparse.Namespace, + sdxl: bool, + lora: bool, + textual_inversion: bool, + is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA + sd3: str = None, + flux: str = None, +): + timestamp = time.time() + + v2 = args.v2 + v_parameterization = args.v_parameterization + reso = args.resolution + + title = args.metadata_title if args.metadata_title is not None else args.output_name + + if args.min_timestep is not None or args.max_timestep is not None: + min_time_step = args.min_timestep if args.min_timestep is not None else 0 + max_time_step = args.max_timestep if args.max_timestep is not None else 1000 + timesteps = (min_time_step, max_time_step) + else: + timesteps = None + + metadata = sai_model_spec.build_metadata( + state_dict, + v2, + v_parameterization, + sdxl, + lora, + textual_inversion, + timestamp, + title=title, + reso=reso, + is_stable_diffusion_ckpt=is_stable_diffusion_ckpt, + author=args.metadata_author, + description=args.metadata_description, + license=args.metadata_license, + tags=args.metadata_tags, + timesteps=timesteps, + clip_skip=args.clip_skip, # None or int + sd3=sd3, + flux=flux, + ) + return metadata + + +def add_sd_models_arguments(parser: argparse.ArgumentParser): + # for pretrained models + parser.add_argument( + "--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む" + ) + parser.add_argument( + "--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする" + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル", + ) + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", + ) + + +def add_optimizer_arguments(parser: argparse.ArgumentParser): + def int_or_float(value): + if value.endswith("%"): + try: + return float(value[:-1]) / 100.0 + except ValueError: + raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage") + try: + float_value = float(value) + if float_value >= 1: + return int(value) + return float(value) + except ValueError: + raise argparse.ArgumentTypeError(f"'{value}' is not an int or float") + + parser.add_argument( + "--optimizer_type", + type=str, + default="", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", + ) + + # backward compatibility + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)", + ) + parser.add_argument( + "--use_lion_optimizer", + action="store_true", + help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)", + ) + + parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") + parser.add_argument( + "--max_grad_norm", + default=1.0, + type=float, + help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない", + ) + + parser.add_argument( + "--optimizer_args", + type=str, + default=None, + nargs="*", + help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")', + ) + + parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ") + parser.add_argument( + "--lr_scheduler_args", + type=str, + default=None, + nargs="*", + help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max100")', + ) + + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor", + ) + parser.add_argument( + "--lr_warmup_steps", + type=int_or_float, + default=0, + help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps" + " / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)", + ) + parser.add_argument( + "--lr_decay_steps", + type=int_or_float, + default=0, + help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps" + " / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)", + ) + parser.add_argument( + "--lr_scheduler_num_cycles", + type=int, + default=1, + help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数", + ) + parser.add_argument( + "--lr_scheduler_power", + type=float, + default=1, + help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", + ) + parser.add_argument( + "--fused_backward_pass", + action="store_true", + help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL" + + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効", + ) + parser.add_argument( + "--lr_scheduler_timescale", + type=int, + default=None, + help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`" + + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", + ) + parser.add_argument( + "--lr_scheduler_min_lr_ratio", + type=float, + default=None, + help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler" + + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", + ) + + +def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): + parser.add_argument( + "--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ" + ) + parser.add_argument( + "--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名" + ) + parser.add_argument( + "--huggingface_repo_id", + type=str, + default=None, + help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名", + ) + parser.add_argument( + "--huggingface_repo_type", + type=str, + default=None, + help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類", + ) + parser.add_argument( + "--huggingface_path_in_repo", + type=str, + default=None, + help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス", + ) + parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン") + parser.add_argument( + "--huggingface_repo_visibility", + type=str, + default=None, + help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)", + ) + parser.add_argument( + "--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する" + ) + parser.add_argument( + "--resume_from_huggingface", + action="store_true", + help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})", + ) + parser.add_argument( + "--async_upload", + action="store_true", + help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする", + ) + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving / 保存時に精度を変更して保存する", + ) + parser.add_argument( + "--save_every_n_epochs", + type=int, + default=None, + help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する", + ) + parser.add_argument( + "--save_every_n_steps", + type=int, + default=None, + help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する", + ) + parser.add_argument( + "--save_n_epoch_ratio", + type=int, + default=None, + help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)", + ) + parser.add_argument( + "--save_last_n_epochs", + type=int, + default=None, + help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する(古いチェックポイントは削除する)", + ) + parser.add_argument( + "--save_last_n_epochs_state", + type=int, + default=None, + help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きする)", + ) + parser.add_argument( + "--save_last_n_steps", + type=int, + default=None, + help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)", + ) + parser.add_argument( + "--save_last_n_steps_state", + type=int, + default=None, + help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する。--save_last_n_stepsを上書きする)", + ) + parser.add_argument( + "--save_state", + action="store_true", + help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する", + ) + parser.add_argument( + "--save_state_on_train_end", + action="store_true", + help="save training state (including optimizer states etc.) on train end / optimizerなど学習状態も含めたstateを学習完了時に保存する", + ) + parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") + + parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ") + parser.add_argument( + "--max_token_length", + type=int, + default=None, + choices=[None, 150, 225], + help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)", + ) + parser.add_argument( + "--mem_eff_attn", + action="store_true", + help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う", + ) + parser.add_argument( + "--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う" + ) + parser.add_argument( + "--dynamo_backend", + type=str, + default="inductor", + # available backends: + # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 + # https://pytorch.org/docs/stable/torch.compiler.html + choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], + help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)", + ) + parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") + parser.add_argument( + "--sdpa", + action="store_true", + help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)", + ) + parser.add_argument( + "--vae", + type=str, + default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ", + ) + + parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") + parser.add_argument( + "--max_train_epochs", + type=int, + default=None, + help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)", + ) + parser.add_argument( + "--max_data_loader_n_workers", + type=int, + default=8, + help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)", + ) + parser.add_argument( + "--persistent_data_loader_workers", + action="store_true", + help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)", + ) + parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") + parser.add_argument( + "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする" + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数", + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help="use mixed precision / 混合精度を使う場合、その精度", + ) + parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") + parser.add_argument( + "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する" + ) # TODO move to SDXL training, because it is not supported by SD1/2 + parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う") + parser.add_argument( + "--fp8_dtype", + type=str, + default="e4m3", + choices=["e4m3", "e5m2"], + help="fp8 dtype selection", + ) + + parser.add_argument( + "--ddp_timeout", + type=int, + default=None, + help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)", + ) + parser.add_argument( + "--ddp_gradient_as_bucket_view", + action="store_true", + help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする", + ) + parser.add_argument( + "--ddp_static_graph", + action="store_true", + help="enable static_graph for DDP / DDPでstatic_graphを有効にする", + ) + parser.add_argument( + "--clip_skip", + type=int, + default=None, + help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)", + ) + parser.add_argument( + "--logging_dir", + type=str, + default=None, + help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する", + ) + parser.add_argument( + "--log_with", + type=str, + default=None, + choices=["tensorboard", "wandb", "all"], + help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)", + ) + parser.add_argument( + "--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列" + ) + parser.add_argument( + "--log_tracker_name", + type=str, + default=None, + help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名", + ) + parser.add_argument( + "--wandb_run_name", + type=str, + default=None, + help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前", + ) + parser.add_argument( + "--log_tracker_config", + type=str, + default=None, + help="path to tracker config file to use for logging / ログ出力に使用するtrackerの設定ファイルのパス", + ) + parser.add_argument( + "--wandb_api_key", + type=str, + default=None, + help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)", + ) + parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する") + + parser.add_argument( + "--noise_offset", + type=float, + default=None, + help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)", + ) + parser.add_argument( + "--noise_offset_random_strength", + action="store_true", + help="use random strength between 0~noise_offset for noise offset. / noise offsetにおいて、0からnoise_offsetの間でランダムな強度を使用します。", + ) + parser.add_argument( + "--multires_noise_iterations", + type=int, + default=None, + help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する(有効にする場合は6-10程度を推奨)", + ) + parser.add_argument( + "--ip_noise_gamma", + type=float, + default=None, + help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) " + + "/ input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)", + ) + parser.add_argument( + "--ip_noise_gamma_random_strength", + action="store_true", + help="Use random strength between 0~ip_noise_gamma for input perturbation noise." + + "/ input perturbation noiseにおいて、0からip_noise_gammaの間でランダムな強度を使用します。", + ) + # parser.add_argument( + # "--perlin_noise", + # type=int, + # default=None, + # help="enable perlin noise and set the octaves / perlin noiseを有効にしてoctavesをこの値に設定する", + # ) + parser.add_argument( + "--multires_noise_discount", + type=float, + default=0.3, + help="set discount value for multires noise (has no effect without --multires_noise_iterations) / Multires noiseのdiscount値を設定する(--multires_noise_iterations指定時のみ有効)", + ) + parser.add_argument( + "--adaptive_noise_scale", + type=float, + default=None, + help="add `latent mean absolute value * this value` to noise_offset (disabled if None, default) / latentの平均値の絶対値 * この値をnoise_offsetに加算する(Noneの場合は無効、デフォルト)", + ) + parser.add_argument( + "--zero_terminal_snr", + action="store_true", + help="fix noise scheduler betas to enforce zero terminal SNR / noise schedulerのbetasを修正して、zero terminal SNRを強制する", + ) + parser.add_argument( + "--min_timestep", + type=int, + default=None, + help="set minimum time step for U-Net training (0~999, default is 0) / U-Net学習時のtime stepの最小値を設定する(0~999で指定、省略時はデフォルト値(0)) ", + ) + parser.add_argument( + "--max_timestep", + type=int, + default=None, + help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))", + ) + parser.add_argument( + "--loss_type", + type=str, + default="l2", + choices=["l1", "l2", "huber", "smooth_l1"], + help="The type of loss function to use (L1, L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L1、L2、Huber、またはsmooth L1)、デフォルトはL2", + ) + parser.add_argument( + "--huber_schedule", + type=str, + default="snr", + choices=["constant", "exponential", "snr"], + help="The scheduling method for Huber loss (constant, exponential, or SNR-based). Only used when loss_type is 'huber' or 'smooth_l1'. default is snr" + + " / Huber損失のスケジューリング方法(constant、exponential、またはSNRベース)。loss_typeが'huber'または'smooth_l1'の場合に有効、デフォルトは snr", + ) + parser.add_argument( + "--huber_c", + type=float, + default=0.1, + help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1" + " / Huber損失の減衰パラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1", + ) + + parser.add_argument( + "--huber_scale", + type=float, + default=1.0, + help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0" + " / Huber損失のスケールパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは1.0", + ) + + parser.add_argument( + "--lowram", + action="store_true", + help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込む等(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)", + ) + parser.add_argument( + "--highvram", + action="store_true", + help="disable low VRAM optimization. e.g. do not clear CUDA cache after each latent caching (for machines which have bigger VRAM) " + + "/ VRAMが少ない環境向け最適化を無効にする。たとえば各latentのキャッシュ後のCUDAキャッシュクリアを行わない等(VRAMが多い環境向け)", + ) + + parser.add_argument( + "--sample_every_n_steps", + type=int, + default=None, + help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する", + ) + parser.add_argument( + "--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する" + ) + parser.add_argument( + "--sample_every_n_epochs", + type=int, + default=None, + help="generate sample images every N epochs (overwrites n_steps) / 学習中のモデルで指定エポックごとにサンプル出力する(ステップ数指定を上書きします)", + ) + parser.add_argument( + "--sample_prompts", + type=str, + default=None, + help="file for prompts to generate sample images / 学習中モデルのサンプル出力用プロンプトのファイル", + ) + parser.add_argument( + "--sample_sampler", + type=str, + default="ddim", + choices=[ + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", + ], + help=f"sampler (scheduler) type for sample images / サンプル出力時のサンプラー(スケジューラ)の種類", + ) + + parser.add_argument( + "--config_file", + type=str, + default=None, + help="using .toml instead of args to pass hyperparameter / ハイパーパラメータを引数ではなく.tomlファイルで渡す", + ) + parser.add_argument( + "--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する" + ) + + # SAI Model spec + parser.add_argument( + "--metadata_title", + type=str, + default=None, + help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name", + ) + parser.add_argument( + "--metadata_author", + type=str, + default=None, + help="author name for model metadata / メタデータに書き込まれるモデル作者名", + ) + parser.add_argument( + "--metadata_description", + type=str, + default=None, + help="description for model metadata / メタデータに書き込まれるモデル説明", + ) + parser.add_argument( + "--metadata_license", + type=str, + default=None, + help="license for model metadata / メタデータに書き込まれるモデルライセンス", + ) + parser.add_argument( + "--metadata_tags", + type=str, + default=None, + help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り", + ) + + if support_dreambooth: + # DreamBooth training + parser.add_argument( + "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" + ) + + +def add_masked_loss_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + parser.add_argument( + "--masked_loss", + action="store_true", + help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要", + ) + + +def add_dit_training_arguments(parser: argparse.ArgumentParser): + # Text encoder related arguments + parser.add_argument( + "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) + parser.add_argument( + "--text_encoder_batch_size", + type=int, + default=None, + help="text encoder batch size (default: None, use dataset's batch size)" + + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)", + ) + + # Model loading optimization + parser.add_argument( + "--disable_mmap_load_safetensors", + action="store_true", + help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", + ) + + # Training arguments. partial copy from Diffusers + parser.add_argument( + "--weighting_scheme", + type=str, + default="uniform", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none", "uniform"], + help="weighting scheme for timestep distribution. Default is uniform, uniform and none are the same behavior" + " / タイムステップ分布の重み付けスキーム、デフォルトはuniform、uniform と none は同じ挙動", + ) + parser.add_argument( + "--logit_mean", + type=float, + default=0.0, + help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均", + ) + parser.add_argument( + "--logit_std", + type=float, + default=1.0, + help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd", + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール", + ) + + # offloading + parser.add_argument( + "--blocks_to_swap", + type=int, + default=None, + help="[EXPERIMENTAL] " + "Sets the number of blocks to swap during the forward and backward passes." + "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)." + " / 順伝播および逆伝播中にスワップするブロックの数を設定します。" + "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", + ) + + +def get_sanitized_config_or_none(args: argparse.Namespace): + # if `--log_config` is enabled, return args for logging. if not, return None. + # when `--log_config is enabled, filter out sensitive values from args + # if wandb is not enabled, the log is not exposed to the public, but it is fine to filter out sensitive values to be safe + + if not args.log_config: + return None + + sensitive_args = ["wandb_api_key", "huggingface_token"] + sensitive_path_args = [ + "pretrained_model_name_or_path", + "vae", + "tokenizer_cache_dir", + "train_data_dir", + "conditioning_data_dir", + "reg_data_dir", + "output_dir", + "logging_dir", + ] + filtered_args = {} + for k, v in vars(args).items(): + # filter out sensitive values and convert to string if necessary + if k not in sensitive_args + sensitive_path_args: + # Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`. + if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int): + filtered_args[k] = v + # accelerate does not support lists + elif isinstance(v, list): + filtered_args[k] = f"{v}" + # accelerate does not support objects + elif isinstance(v, object): + filtered_args[k] = f"{v}" + + return filtered_args + + +# verify command line args for training +def verify_command_line_training_args(args: argparse.Namespace): + # if wandb is enabled, the command line is exposed to the public + # check whether sensitive options are included in the command line arguments + # if so, warn or inform the user to move them to the configuration file + # wandbが有効な場合、コマンドラインが公開される + # 学習用のコマンドライン引数に敏感なオプションが含まれているかどうかを確認し、 + # 含まれている場合は設定ファイルに移動するようにユーザーに警告または通知する + + wandb_enabled = args.log_with is not None and args.log_with != "tensorboard" # "all" or "wandb" + if not wandb_enabled: + return + + sensitive_args = ["wandb_api_key", "huggingface_token"] + sensitive_path_args = [ + "pretrained_model_name_or_path", + "vae", + "tokenizer_cache_dir", + "train_data_dir", + "conditioning_data_dir", + "reg_data_dir", + "output_dir", + "logging_dir", + ] + + for arg in sensitive_args: + if getattr(args, arg, None) is not None: + logger.warning( + f"wandb is enabled, but option `{arg}` is included in the command line. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file." + + f" / wandbが有効で、かつオプション `{arg}` がコマンドラインに含まれています。コマンドラインは公開されるため、`.toml`ファイルに移動することをお勧めします。" + ) + + # if path is absolute, it may include sensitive information + for arg in sensitive_path_args: + if getattr(args, arg, None) is not None and os.path.isabs(getattr(args, arg)): + logger.info( + f"wandb is enabled, but option `{arg}` is included in the command line and it is an absolute path. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file or use relative path." + + f" / wandbが有効で、かつオプション `{arg}` がコマンドラインに含まれており、絶対パスです。コマンドラインは公開されるため、`.toml`ファイルに移動するか、相対パスを使用することをお勧めします。" + ) + + if getattr(args, "config_file", None) is not None: + logger.info( + f"wandb is enabled, but option `config_file` is included in the command line. Because the command line is exposed to the public, please be careful about the information included in the path." + + f" / wandbが有効で、かつオプション `config_file` がコマンドラインに含まれています。コマンドラインは公開されるため、パスに含まれる情報にご注意ください。" + ) + + # other sensitive options + if args.huggingface_repo_id is not None and args.huggingface_repo_visibility != "public": + logger.info( + f"wandb is enabled, but option huggingface_repo_id is included in the command line and huggingface_repo_visibility is not 'public'. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file." + + f" / wandbが有効で、かつオプション huggingface_repo_id がコマンドラインに含まれており、huggingface_repo_visibility が 'public' ではありません。コマンドラインは公開されるため、`.toml`ファイルに移動することをお勧めします。" + ) + + +def enable_high_vram(args: argparse.Namespace): + if args.highvram: + logger.info("highvram is enabled / highvramが有効です") + global HIGH_VRAM + HIGH_VRAM = True + + +def verify_training_args(args: argparse.Namespace): + r""" + Verify training arguments. Also reflect highvram option to global variable + 学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する + """ + enable_high_vram(args) + + if args.v2 and args.clip_skip is not None: + logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + + if args.cache_latents_to_disk and not args.cache_latents: + args.cache_latents = True + logger.warning( + "cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします" + ) + + # noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time + # # Listを使って数えてもいいけど並べてしまえ + # if args.noise_offset is not None and args.multires_noise_iterations is not None: + # raise ValueError( + # "noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にできません" + # ) + # if args.noise_offset is not None and args.perlin_noise is not None: + # raise ValueError("noise_offset and perlin_noise cannot be enabled at the same time / noise_offsetとperlin_noiseは同時に有効にできません") + # if args.perlin_noise is not None and args.multires_noise_iterations is not None: + # raise ValueError( + # "perlin_noise and multires_noise_iterations cannot be enabled at the same time / perlin_noiseとmultires_noise_iterationsを同時に有効にできません" + # ) + + if args.adaptive_noise_scale is not None and args.noise_offset is None: + raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です") + + if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization: + raise ValueError( + "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます" + ) + + if args.v_pred_like_loss and args.v_parameterization: + raise ValueError( + "v_pred_like_loss cannot be enabled with v_parameterization / v_pred_like_lossはv_parameterizationが有効なときには有効にできません" + ) + + if args.zero_terminal_snr and not args.v_parameterization: + logger.warning( + f"zero_terminal_snr is enabled, but v_parameterization is not enabled. training will be unexpected" + + " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります" + ) + + if args.sample_every_n_epochs is not None and args.sample_every_n_epochs <= 0: + logger.warning( + "sample_every_n_epochs is less than or equal to 0, so it will be disabled / sample_every_n_epochsに0以下の値が指定されたため無効になります" + ) + args.sample_every_n_epochs = None + + if args.sample_every_n_steps is not None and args.sample_every_n_steps <= 0: + logger.warning( + "sample_every_n_steps is less than or equal to 0, so it will be disabled / sample_every_n_stepsに0以下の値が指定されたため無効になります" + ) + args.sample_every_n_steps = None + + +def add_dataset_arguments( + parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool +): + # dataset common + parser.add_argument( + "--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ" + ) + parser.add_argument( + "--cache_info", + action="store_true", + help="cache meta information (caption and image size) for faster dataset loading. only available for DreamBooth" + + " / メタ情報(キャプションとサイズ)をキャッシュしてデータセット読み込みを高速化する。DreamBooth方式のみ有効", + ) + parser.add_argument( + "--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする" + ) + parser.add_argument("--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字") + parser.add_argument( + "--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子" + ) + parser.add_argument( + "--caption_extention", + type=str, + default=None, + help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)", + ) + parser.add_argument( + "--keep_tokens", + type=int, + default=0, + help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)", + ) + parser.add_argument( + "--keep_tokens_separator", + type=str, + default="", + help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens." + + " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。", + ) + parser.add_argument( + "--secondary_separator", + type=str, + default=None, + help="a secondary separator for caption. This separator is replaced to caption_separator after dropping/shuffling caption" + + " / captionのセカンダリ区切り文字。この区切り文字はcaptionのドロップやシャッフル後にcaption_separatorに置き換えられる", + ) + parser.add_argument( + "--enable_wildcard", + action="store_true", + help="enable wildcard for caption (e.g. '{image|picture|rendition}') / captionのワイルドカードを有効にする(例:'{image|picture|rendition}')", + ) + parser.add_argument( + "--caption_prefix", + type=str, + default=None, + help="prefix for caption text / captionのテキストの先頭に付ける文字列", + ) + parser.add_argument( + "--caption_suffix", + type=str, + default=None, + help="suffix for caption text / captionのテキストの末尾に付ける文字列", + ) + parser.add_argument( + "--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする" + ) + parser.add_argument( + "--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする" + ) + parser.add_argument( + "--face_crop_aug_range", + type=str, + default=None, + help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)", + ) + parser.add_argument( + "--random_crop", + action="store_true", + help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)", + ) + parser.add_argument( + "--debug_dataset", + action="store_true", + help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)", + ) + parser.add_argument( + "--resolution", + type=str, + default=None, + help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)", + ) + parser.add_argument( + "--cache_latents", + action="store_true", + help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheする(augmentationは使用不可) ", + ) + parser.add_argument( + "--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ" + ) + parser.add_argument( + "--cache_latents_to_disk", + action="store_true", + help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)", + ) + parser.add_argument( + "--skip_cache_check", + action="store_true", + help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist" + " / cacheの内容の検証をスキップする(latentとテキストエンコーダの出力)。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる", + ) + parser.add_argument( + "--enable_bucket", + action="store_true", + help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする", + ) + parser.add_argument( + "--min_bucket_reso", + type=int, + default=256, + help="minimum resolution for buckets, must be divisible by bucket_reso_steps " + " / bucketの最小解像度、bucket_reso_stepsで割り切れる必要があります", + ) + parser.add_argument( + "--max_bucket_reso", + type=int, + default=1024, + help="maximum resolution for buckets, must be divisible by bucket_reso_steps " + " / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります", + ) + parser.add_argument( + "--bucket_reso_steps", + type=int, + default=64, + help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します", + ) + parser.add_argument( + "--bucket_no_upscale", + action="store_true", + help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します", + ) + + parser.add_argument( + "--token_warmup_min", + type=int, + default=1, + help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する", + ) + parser.add_argument( + "--token_warmup_step", + type=float, + default=0, + help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)", + ) + parser.add_argument( + "--alpha_mask", + action="store_true", + help="use alpha channel as mask for training / 画像のアルファチャンネルをlossのマスクに使用する", + ) + + parser.add_argument( + "--dataset_class", + type=str, + default=None, + help="dataset class for arbitrary dataset (package.module.Class) / 任意のデータセットを用いるときのクラス名 (package.module.Class)", + ) + + if support_caption_dropout: + # Textual Inversion はcaptionのdropoutをsupportしない + # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに + parser.add_argument( + "--caption_dropout_rate", type=float, default=0.0, help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合" + ) + parser.add_argument( + "--caption_dropout_every_n_epochs", + type=int, + default=0, + help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする", + ) + parser.add_argument( + "--caption_tag_dropout_rate", + type=float, + default=0.0, + help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合", + ) + + if support_dreambooth: + # DreamBooth dataset + parser.add_argument( + "--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ" + ) + + if support_caption: + # caption dataset + parser.add_argument( + "--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル" + ) + parser.add_argument( + "--dataset_repeats", + type=int, + default=1, + help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数", + ) + + +def add_sd_saving_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--save_model_as", + type=str, + default=None, + choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"], + help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)", + ) + parser.add_argument( + "--use_safetensors", + action="store_true", + help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)", + ) + + +def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser): + if not args.config_file: + return args + + config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file + + if args.output_config: + # check if config file exists + if os.path.exists(config_path): + logger.error(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}") + exit(1) + + # convert args to dictionary + args_dict = vars(args) + + # remove unnecessary keys + for key in ["config_file", "output_config", "wandb_api_key"]: + if key in args_dict: + del args_dict[key] + + # get default args from parser + default_args = vars(parser.parse_args([])) + + # remove default values: cannot use args_dict.items directly because it will be changed during iteration + for key, value in list(args_dict.items()): + if key in default_args and value == default_args[key]: + del args_dict[key] + + # convert Path to str in dictionary + for key, value in args_dict.items(): + if isinstance(value, pathlib.Path): + args_dict[key] = str(value) + + # convert to toml and output to file + with open(config_path, "w") as f: + toml.dump(args_dict, f) + + logger.info(f"Saved config file / 設定ファイルを保存しました: {config_path}") + exit(0) + + if not os.path.exists(config_path): + logger.info(f"{config_path} not found.") + exit(1) + + logger.info(f"Loading settings from {config_path}...") + with open(config_path, "r", encoding="utf-8") as f: + config_dict = toml.load(f) + + # combine all sections into one + ignore_nesting_dict = {} + for section_name, section_dict in config_dict.items(): + # if value is not dict, save key and value as is + if not isinstance(section_dict, dict): + ignore_nesting_dict[section_name] = section_dict + continue + + # if value is dict, save all key and value into one dict + for key, value in section_dict.items(): + ignore_nesting_dict[key] = value + + config_args = argparse.Namespace(**ignore_nesting_dict) + args = parser.parse_args(namespace=config_args) + args.config_file = os.path.splitext(args.config_file)[0] + logger.info(args.config_file) + + return args + + +# endregion + +# region utils + + +def resume_from_local_or_hf_if_specified(accelerator, args): + if not args.resume: + return + + if not args.resume_from_huggingface: + logger.info(f"resume training from local state: {args.resume}") + accelerator.load_state(args.resume) + return + + logger.info(f"resume training from huggingface state: {args.resume}") + repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1] + path_in_repo = "/".join(args.resume.split("/")[2:]) + revision = None + repo_type = None + if ":" in path_in_repo: + divided = path_in_repo.split(":") + if len(divided) == 2: + path_in_repo, revision = divided + repo_type = "model" + else: + path_in_repo, revision, repo_type = divided + logger.info(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}") + + list_files = huggingface_util.list_dir( + repo_id=repo_id, + subfolder=path_in_repo, + revision=revision, + token=args.huggingface_token, + repo_type=repo_type, + ) + + async def download(filename) -> str: + def task(): + return hf_hub_download( + repo_id=repo_id, + filename=filename, + revision=revision, + repo_type=repo_type, + token=args.huggingface_token, + ) + + return await asyncio.get_event_loop().run_in_executor(None, task) + + loop = asyncio.get_event_loop() + results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files])) + if len(results) == 0: + raise ValueError( + "No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした" + ) + dirname = os.path.dirname(results[0]) + accelerator.load_state(dirname) + + +def get_optimizer(args, trainable_params): + # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" + + optimizer_type = args.optimizer_type + if args.use_8bit_adam: + assert ( + not args.use_lion_optimizer + ), "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています" + assert ( + optimizer_type is None or optimizer_type == "" + ), "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています" + optimizer_type = "AdamW8bit" + + elif args.use_lion_optimizer: + assert ( + optimizer_type is None or optimizer_type == "" + ), "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています" + optimizer_type = "Lion" + + if optimizer_type is None or optimizer_type == "": + optimizer_type = "AdamW" + optimizer_type = optimizer_type.lower() + + if args.fused_backward_pass: + assert ( + optimizer_type == "Adafactor".lower() + ), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します" + assert ( + args.gradient_accumulation_steps == 1 + ), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません" + + # 引数を分解する + optimizer_kwargs = {} + if args.optimizer_args is not None and len(args.optimizer_args) > 0: + for arg in args.optimizer_args: + key, value = arg.split("=") + value = ast.literal_eval(value) + + # value = value.split(",") + # for i in range(len(value)): + # if value[i].lower() == "true" or value[i].lower() == "false": + # value[i] = value[i].lower() == "true" + # else: + # value[i] = ast.float(value[i]) + # if len(value) == 1: + # value = value[0] + # else: + # value = tuple(value) + + optimizer_kwargs[key] = value + # logger.info(f"optkwargs {optimizer}_{kwargs}") + + lr = args.learning_rate + optimizer = None + optimizer_class = None + + if optimizer_type == "Lion".lower(): + try: + import lion_pytorch + except ImportError: + raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです") + logger.info(f"use Lion optimizer | {optimizer_kwargs}") + optimizer_class = lion_pytorch.Lion + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type.endswith("8bit".lower()): + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") + + if optimizer_type == "AdamW8bit".lower(): + logger.info(f"use 8-bit AdamW optimizer | {optimizer_kwargs}") + optimizer_class = bnb.optim.AdamW8bit + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "SGDNesterov8bit".lower(): + logger.info(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}") + if "momentum" not in optimizer_kwargs: + logger.warning( + f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します" + ) + optimizer_kwargs["momentum"] = 0.9 + + optimizer_class = bnb.optim.SGD8bit + optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) + + elif optimizer_type == "Lion8bit".lower(): + logger.info(f"use 8-bit Lion optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.Lion8bit + except AttributeError: + raise AttributeError( + "No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください" + ) + elif optimizer_type == "PagedAdamW8bit".lower(): + logger.info(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.PagedAdamW8bit + except AttributeError: + raise AttributeError( + "No PagedAdamW8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) + elif optimizer_type == "PagedLion8bit".lower(): + logger.info(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.PagedLion8bit + except AttributeError: + raise AttributeError( + "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) + elif optimizer_type == "Ademamix8bit".lower(): + logger.info(f"use 8-bit Ademamix optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.AdEMAMix8bit + except AttributeError: + raise AttributeError( + "No Ademamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / Ademamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) + elif optimizer_type == "PagedAdemamix8bit".lower(): + logger.info(f"use 8-bit PagedAdemamix optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.PagedAdEMAMix8bit + except AttributeError: + raise AttributeError( + "No PagedAdemamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / PagedAdemamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "PagedAdamW".lower(): + logger.info(f"use PagedAdamW optimizer | {optimizer_kwargs}") + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") + try: + optimizer_class = bnb.optim.PagedAdamW + except AttributeError: + raise AttributeError( + "No PagedAdamW. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamWが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "PagedAdamW32bit".lower(): + logger.info(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}") + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") + try: + optimizer_class = bnb.optim.PagedAdamW32bit + except AttributeError: + raise AttributeError( + "No PagedAdamW32bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW32bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "SGDNesterov".lower(): + logger.info(f"use SGD with Nesterov optimizer | {optimizer_kwargs}") + if "momentum" not in optimizer_kwargs: + logger.info( + f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します" + ) + optimizer_kwargs["momentum"] = 0.9 + + optimizer_class = torch.optim.SGD + optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) + + elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower(): + # check lr and lr_count, and logger.info warning + actual_lr = lr + lr_count = 1 + if type(trainable_params) == list and type(trainable_params[0]) == dict: + lrs = set() + actual_lr = trainable_params[0].get("lr", actual_lr) + for group in trainable_params: + lrs.add(group.get("lr", actual_lr)) + lr_count = len(lrs) + + if actual_lr <= 0.1: + logger.warning( + f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}" + ) + logger.warning("recommend option: lr=1.0 / 推奨は1.0です") + if lr_count > 1: + logger.warning( + f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" + ) + + if optimizer_type.startswith("DAdapt".lower()): + # DAdaptation family + # check dadaptation is installed + try: + import dadaptation + import dadaptation.experimental as experimental + except ImportError: + raise ImportError("No dadaptation / dadaptation がインストールされていないようです") + + # set optimizer + if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower(): + optimizer_class = experimental.DAdaptAdamPreprint + logger.info(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptAdaGrad".lower(): + optimizer_class = dadaptation.DAdaptAdaGrad + logger.info(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptAdam".lower(): + optimizer_class = dadaptation.DAdaptAdam + logger.info(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptAdan".lower(): + optimizer_class = dadaptation.DAdaptAdan + logger.info(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptAdanIP".lower(): + optimizer_class = experimental.DAdaptAdanIP + logger.info(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptLion".lower(): + optimizer_class = dadaptation.DAdaptLion + logger.info(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptSGD".lower(): + optimizer_class = dadaptation.DAdaptSGD + logger.info(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") + else: + raise ValueError(f"Unknown optimizer type: {optimizer_type}") + + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + else: + # Prodigy + # check Prodigy is installed + try: + import prodigyopt + except ImportError: + raise ImportError("No Prodigy / Prodigy がインストールされていないようです") + + logger.info(f"use Prodigy optimizer | {optimizer_kwargs}") + optimizer_class = prodigyopt.Prodigy + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "Adafactor".lower(): + # 引数を確認して適宜補正する + if "relative_step" not in optimizer_kwargs: + optimizer_kwargs["relative_step"] = True # default + if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False): + logger.info( + f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします" + ) + optimizer_kwargs["relative_step"] = True + logger.info(f"use Adafactor optimizer | {optimizer_kwargs}") + + if optimizer_kwargs["relative_step"]: + logger.info(f"relative_step is true / relative_stepがtrueです") + if lr != 0.0: + logger.warning(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます") + args.learning_rate = None + + # trainable_paramsがgroupだった時の処理:lrを削除する + if type(trainable_params) == list and type(trainable_params[0]) == dict: + has_group_lr = False + for group in trainable_params: + p = group.pop("lr", None) + has_group_lr = has_group_lr or (p is not None) + + if has_group_lr: + # 一応argsを無効にしておく TODO 依存関係が逆転してるのであまり望ましくない + logger.warning(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます") + args.unet_lr = None + args.text_encoder_lr = None + + if args.lr_scheduler != "adafactor": + logger.info(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します") + args.lr_scheduler = f"adafactor:{lr}" # ちょっと微妙だけど + + lr = None + else: + if args.max_grad_norm != 0.0: + logger.warning( + f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません" + ) + if args.lr_scheduler != "constant_with_warmup": + logger.warning(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません") + if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0: + logger.warning(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません") + + optimizer_class = transformers.optimization.Adafactor + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type == "AdamW".lower(): + logger.info(f"use AdamW optimizer | {optimizer_kwargs}") + optimizer_class = torch.optim.AdamW + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type.endswith("schedulefree".lower()): + if optimizer_type.lower() == "ProdigyPlusScheduleFree".lower(): + from ..prodigyplusschedulefree.prodigy_plus_schedulefree import ProdigyPlusScheduleFree + optimizer_class = ProdigyPlusScheduleFree + logger.info(f"use ProdigyScheduleFree optimizer | {optimizer_kwargs}") + else: + try: + import schedulefree as sf + except ImportError: + raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") + if optimizer_type == "AdamWScheduleFree".lower(): + optimizer_class = sf.AdamWScheduleFree + logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}") + elif optimizer_type == "SGDScheduleFree".lower(): + optimizer_class = sf.SGDScheduleFree + logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}") + else: + raise ValueError(f"Unknown optimizer type: {optimizer_type}") + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + # make optimizer as train mode: we don't need to call train again, because eval will not be called in training loop + optimizer.train() + + elif optimizer_type == "CAME".lower(): + logger.info(f"use CAME optimizer | {optimizer_kwargs}") + try: + import came_pytorch + except ImportError: + raise ImportError("No came-pytorch") + try: + optimizer_class = came_pytorch.CAME + except AttributeError: + raise AttributeError( + "No CAME. Please install came-pytorch." + ) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + if optimizer is None: + # 任意のoptimizerを使う + case_sensitive_optimizer_type = args.optimizer_type # not lower + logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}") + if "." not in case_sensitive_optimizer_type: # from torch.optim + optimizer_module = torch.optim + else: # from other library + values = case_sensitive_optimizer_type.split(".") + optimizer_module = importlib.import_module(".".join(values[:-1])) + case_sensitive_optimizer_type = values[-1] + + optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ + optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) + + return optimizer_name, optimizer_args, optimizer + + +def get_optimizer_train_eval_fn(optimizer: Optimizer, args: argparse.Namespace) -> Tuple[Callable, Callable]: + if not is_schedulefree_optimizer(optimizer, args): + # return dummy func + return lambda: None, lambda: None + + # get train and eval functions from optimizer + train_fn = optimizer.train + eval_fn = optimizer.eval + + return train_fn, eval_fn + + +def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool: + return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper + + +def get_dummy_scheduler(optimizer: Optimizer) -> Any: + # dummy scheduler for schedulefree optimizer. supports only empty step(), get_last_lr() and optimizers. + # this scheduler is used for logging only. + # this isn't be wrapped by accelerator because of this class is not a subclass of torch.optim.lr_scheduler._LRScheduler + class DummyScheduler: + def __init__(self, optimizer: Optimizer): + self.optimizer = optimizer + + def step(self): + pass + + def get_last_lr(self): + return [group["lr"] for group in self.optimizer.param_groups] + + return DummyScheduler(optimizer) + + +# Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler +# Add some checking and features to the original function. + + +def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): + """ + Unified API to get any scheduler from its name. + """ + # if schedulefree optimizer, return dummy scheduler + if is_schedulefree_optimizer(optimizer, args): + return get_dummy_scheduler(optimizer) + + name = args.lr_scheduler + num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps + num_warmup_steps: Optional[int] = ( + int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps + ) + num_decay_steps: Optional[int] = ( + int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps + ) + num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps + num_cycles = args.lr_scheduler_num_cycles + power = args.lr_scheduler_power + timescale = args.lr_scheduler_timescale + min_lr_ratio = args.lr_scheduler_min_lr_ratio + + lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs + if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0: + for arg in args.lr_scheduler_args: + key, value = arg.split("=") + value = ast.literal_eval(value) + lr_scheduler_kwargs[key] = value + + def wrap_check_needless_num_warmup_steps(return_vals): + if num_warmup_steps is not None and num_warmup_steps != 0: + raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.") + return return_vals + + # using any lr_scheduler from other library + if args.lr_scheduler_type: + lr_scheduler_type = args.lr_scheduler_type + logger.info(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler") + if "." not in lr_scheduler_type: # default to use torch.optim + lr_scheduler_module = torch.optim.lr_scheduler + else: + values = lr_scheduler_type.split(".") + lr_scheduler_module = importlib.import_module(".".join(values[:-1])) + lr_scheduler_type = values[-1] + lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type) + lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs) + return wrap_check_needless_num_warmup_steps(lr_scheduler) + + if name.startswith("adafactor"): + assert ( + type(optimizer) == transformers.optimization.Adafactor + ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください" + initial_lr = float(name.split(":")[1]) + # logger.info(f"adafactor scheduler init lr {initial_lr}") + return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) + + if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value: + name = DiffusersSchedulerType(name) + schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name] + return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs + + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + + if name == SchedulerType.CONSTANT: + return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs)) + + # All other schedulers require `num_warmup_steps` + if num_warmup_steps is None: + raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") + + if name == SchedulerType.CONSTANT_WITH_WARMUP: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs) + + if name == SchedulerType.INVERSE_SQRT: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs) + + # All other schedulers require `num_training_steps` + if num_training_steps is None: + raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") + + if name == SchedulerType.COSINE_WITH_RESTARTS: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + **lr_scheduler_kwargs, + ) + + if name == SchedulerType.POLYNOMIAL: + return schedule_func( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs + ) + + if name == SchedulerType.COSINE_WITH_MIN_LR: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles / 2, + min_lr_rate=min_lr_ratio, + **lr_scheduler_kwargs, + ) + + # these schedulers do not require `num_decay_steps` + if name == SchedulerType.LINEAR or name == SchedulerType.COSINE: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **lr_scheduler_kwargs, + ) + + # All other schedulers require `num_decay_steps` + if num_decay_steps is None: + raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") + if name == SchedulerType.WARMUP_STABLE_DECAY: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_stable_steps=num_stable_steps, + num_decay_steps=num_decay_steps, + num_cycles=num_cycles / 2, + min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0, + **lr_scheduler_kwargs, + ) + + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_decay_steps=num_decay_steps, + **lr_scheduler_kwargs, + ) + + +def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): + # backward compatibility + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + args.caption_extention = None + + # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください" + if args.resolution is not None: + args.resolution = tuple([int(r) for r in args.resolution.split(",")]) + if len(args.resolution) == 1: + args.resolution = (args.resolution[0], args.resolution[0]) + assert ( + len(args.resolution) == 2 + ), f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}" + + if args.face_crop_aug_range is not None: + args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(",")]) + assert ( + len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1] + ), f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}" + else: + args.face_crop_aug_range = None + + if support_metadata: + if args.in_json is not None and (args.color_aug or args.random_crop): + logger.warning( + f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます" + ) + + +def prepare_accelerator(args: argparse.Namespace): + """ + this function also prepares deepspeed plugin + """ + + if args.logging_dir is None: + logging_dir = None + else: + log_prefix = "" if args.log_prefix is None else args.log_prefix + logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime()) + + if args.log_with is None: + if logging_dir is not None: + log_with = "tensorboard" + else: + log_with = None + else: + log_with = args.log_with + if log_with in ["tensorboard", "all"]: + if logging_dir is None: + raise ValueError( + "logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください" + ) + if log_with in ["wandb", "all"]: + try: + import wandb + except ImportError: + raise ImportError("No wandb / wandb がインストールされていないようです") + if logging_dir is not None: + os.makedirs(logging_dir, exist_ok=True) + os.environ["WANDB_DIR"] = logging_dir + if args.wandb_api_key is not None: + wandb.login(key=args.wandb_api_key) + + # torch.compile If NO, torch.compile is not used." + dynamo_backend = "NO" + if args.torch_compile: + dynamo_backend = args.dynamo_backend + + kwargs_handlers = [ + ( + InitProcessGroupKwargs( + backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method=( + "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None + ), + timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None, + ) + if torch.cuda.device_count() > 1 + else None + ), + ( + DistributedDataParallelKwargs( + gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph + ) + if args.ddp_gradient_as_bucket_view or args.ddp_static_graph + else None + ), + ] + kwargs_handlers = [i for i in kwargs_handlers if i is not None] + deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=log_with, + project_dir=logging_dir, + kwargs_handlers=kwargs_handlers, + dynamo_backend=dynamo_backend, + deepspeed_plugin=deepspeed_plugin, + ) + print("accelerator device:", accelerator.device) + return accelerator + + +def prepare_dtype(args: argparse.Namespace): + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + save_dtype = None + if args.save_precision == "fp16": + save_dtype = torch.float16 + elif args.save_precision == "bf16": + save_dtype = torch.bfloat16 + elif args.save_precision == "fp32": + save_dtype = torch.float32 + elif args.save_precision == "fp8_e4m3fn": + save_dtype = torch.float8_e4m3fn + elif args.save_precision == "fp8_e5m2": + save_dtype = torch.float8_e5m2 + + return weight_dtype, save_dtype + + +def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", unet_use_linear_projection_in_v2=False): + name_or_path = args.pretrained_model_name_or_path + name_or_path = os.path.realpath(name_or_path) if os.path.islink(name_or_path) else name_or_path + load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers + if load_stable_diffusion_format: + logger.info(f"load StableDiffusion checkpoint: {name_or_path}") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint( + args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2 + ) + else: + # Diffusers model is loaded to CPU + logger.info(f"load Diffusers pretrained models: {name_or_path}") + try: + pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None) + except EnvironmentError as ex: + logger.error( + f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}" + ) + raise ex + text_encoder = pipe.text_encoder + vae = pipe.vae + unet = pipe.unet + del pipe + + # Diffusers U-Net to original U-Net + # TODO *.ckpt/*.safetensorsのv2と同じ形式にここで変換すると良さそう + # logger.info(f"unet config: {unet.config}") + original_unet = UNet2DConditionModel( + unet.config.sample_size, + unet.config.attention_head_dim, + unet.config.cross_attention_dim, + unet.config.use_linear_projection, + unet.config.upcast_attention, + ) + original_unet.load_state_dict(unet.state_dict()) + unet = original_unet + logger.info("U-Net converted to original U-Net") + + # VAEを読み込む + if args.vae is not None: + vae = model_util.load_vae(args.vae, weight_dtype) + logger.info("additional VAE loaded") + + return text_encoder, vae, unet, load_stable_diffusion_format + + +def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False): + for pi in range(accelerator.state.num_processes): + if pi == accelerator.state.local_process_index: + logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + + text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model( + args, + weight_dtype, + accelerator.device if args.lowram else "cpu", + unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2, + ) + # work on low-ram device + if args.lowram: + text_encoder.to(accelerator.device) + unet.to(accelerator.device) + vae.to(accelerator.device) + + clean_memory_on_device(accelerator.device) + accelerator.wait_for_everyone() + return text_encoder, vae, unet, load_stable_diffusion_format + + +def patch_accelerator_for_fp16_training(accelerator): + org_unscale_grads = accelerator.scaler._unscale_grads_ + + def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + return org_unscale_grads(optimizer, inv_scale, found_inf, True) + + accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + + +def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None): + # with no_token_padding, the length is not max length, return result immediately + if input_ids.size()[-1] != tokenizer.model_max_length: + return text_encoder(input_ids)[0] + + # input_ids: b,n,77 + b_size = input_ids.size()[0] + input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 + + if args.clip_skip is None: + encoder_hidden_states = text_encoder(input_ids)[0] + else: + enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out["hidden_states"][-args.clip_skip] + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + + # bs*3, 77, 768 or 1024 + encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) + + if args.max_token_length is not None: + if args.v2: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + chunk = encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: ... の三連を ... へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + states_list.append( + encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2] + ) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) + + if weight_dtype is not None: + # this is required for additional network training + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + + return encoder_hidden_states + + +def pool_workaround( + text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int +): + r""" + workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output + instead of the hidden states for the EOS token + If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output + + Original code from CLIP's pooling function: + + \# text_embeds.shape = [batch_size, sequence_length, transformer.width] + \# take features from the eot embedding (eot_token is the highest number in each sequence) + \# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + ] + """ + + # input_ids: b*n,77 + # find index for EOS token + + # Following code is not working if one of the input_ids has multiple EOS tokens (very odd case) + # eos_token_index = torch.where(input_ids == eos_token_id)[1] + # eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # Create a mask where the EOS tokens are + eos_token_mask = (input_ids == eos_token_id).int() + + # Use argmax to find the last index of the EOS token for each element in the batch + eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine + eos_token_index = eos_token_index.to(device=last_hidden_state.device) + + # get hidden states for EOS token + pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index] + + # apply projection: projection may be of different dtype than last_hidden_state + pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype)) + pooled_output = pooled_output.to(last_hidden_state.dtype) + + return pooled_output + + +def get_hidden_states_sdxl( + max_token_length: int, + input_ids1: torch.Tensor, + input_ids2: torch.Tensor, + tokenizer1: CLIPTokenizer, + tokenizer2: CLIPTokenizer, + text_encoder1: CLIPTextModel, + text_encoder2: CLIPTextModelWithProjection, + weight_dtype: Optional[str] = None, + accelerator: Optional[Accelerator] = None, +): + # input_ids: b,n,77 -> b*n, 77 + b_size = input_ids1.size()[0] + input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77 + input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77 + + # text_encoder1 + enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True) + hidden_states1 = enc_out["hidden_states"][11] + + # text_encoder2 + enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True) + hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer + + # pool2 = enc_out["text_embeds"] + unwrapped_text_encoder2 = text_encoder2 if accelerator is None else accelerator.unwrap_model(text_encoder2) + pool2 = pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id) + + # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280 + n_size = 1 if max_token_length is None else max_token_length // 75 + hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1])) + hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1])) + + if max_token_length is not None: + # bs*3, 77, 768 or 1024 + # encoder1: ... の三連を ... へ戻す + states_list = [hidden_states1[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, tokenizer1.model_max_length): + states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # の後から の前まで + states_list.append(hidden_states1[:, -1].unsqueeze(1)) # + hidden_states1 = torch.cat(states_list, dim=1) + + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [hidden_states2[:, 0].unsqueeze(1)] # + for i in range(1, max_token_length, tokenizer2.model_max_length): + chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # の後から 最後の前まで + # this causes an error: + # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation + # if i > 1: + # for j in range(len(chunk)): # batch_size + # if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり ...のパターン + # chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(hidden_states2[:, -1].unsqueeze(1)) # のどちらか + hidden_states2 = torch.cat(states_list, dim=1) + + # pool はnの最初のものを使う + pool2 = pool2[::n_size] + + if weight_dtype is not None: + # this is required for additional network training + hidden_states1 = hidden_states1.to(weight_dtype) + hidden_states2 = hidden_states2.to(weight_dtype) + + return hidden_states1, hidden_states2, pool2 + + +def default_if_none(value, default): + return default if value is None else value + + +def get_epoch_ckpt_name(args: argparse.Namespace, ext: str, epoch_no: int): + model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME) + return EPOCH_FILE_NAME.format(model_name, epoch_no) + ext + + +def get_step_ckpt_name(args: argparse.Namespace, ext: str, step_no: int): + model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME) + return STEP_FILE_NAME.format(model_name, step_no) + ext + + +def get_last_ckpt_name(args: argparse.Namespace, ext: str): + model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME) + return model_name + ext + + +def get_remove_epoch_no(args: argparse.Namespace, epoch_no: int): + if args.save_last_n_epochs is None: + return None + + remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs + if remove_epoch_no < 0: + return None + return remove_epoch_no + + +def get_remove_step_no(args: argparse.Namespace, step_no: int): + if args.save_last_n_steps is None: + return None + + # last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する + # save_every_n_steps=10, save_last_n_steps=30の場合、50step目には30step分残し、10step目を削除する + remove_step_no = step_no - args.save_last_n_steps - 1 + remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps) + if remove_step_no < 0: + return None + return remove_step_no + + +# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している +# on_epoch_end: Trueならepoch終了時、Falseならstep経過時 +def save_sd_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + src_path: str, + save_stable_diffusion_format: bool, + use_safetensors: bool, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + text_encoder, + unet, + vae, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True) + model_util.save_stable_diffusion_checkpoint( + args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, sai_metadata, save_dtype, vae + ) + + def diffusers_saver(out_dir): + model_util.save_diffusers_checkpoint( + args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors + ) + + save_sd_model_on_epoch_end_or_stepwise_common( + args, + on_epoch_end, + accelerator, + save_stable_diffusion_format, + use_safetensors, + epoch, + num_train_epochs, + global_step, + sd_saver, + diffusers_saver, + ) + + +def save_sd_model_on_epoch_end_or_stepwise_common( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_stable_diffusion_format: bool, + use_safetensors: bool, + epoch: int, + num_train_epochs: int, + global_step: int, + sd_saver, + diffusers_saver, +): + if on_epoch_end: + epoch_no = epoch + 1 + saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs + if not saving: + return + + model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME) + remove_no = get_remove_epoch_no(args, epoch_no) + else: + # 保存するか否かは呼び出し側で判断済み + + model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME) + epoch_no = epoch # 例: 最初のepochの途中で保存したら0になる、SDモデルに保存される + remove_no = get_remove_step_no(args, global_step) + + os.makedirs(args.output_dir, exist_ok=True) + if save_stable_diffusion_format: + ext = ".safetensors" if use_safetensors else ".ckpt" + + if on_epoch_end: + ckpt_name = get_epoch_ckpt_name(args, ext, epoch_no) + else: + ckpt_name = get_step_ckpt_name(args, ext, global_step) + + ckpt_file = os.path.join(args.output_dir, ckpt_name) + logger.info("") + logger.info(f"saving checkpoint: {ckpt_file}") + sd_saver(ckpt_file, epoch_no, global_step) + + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name) + + # remove older checkpoints + if remove_no is not None: + if on_epoch_end: + remove_ckpt_name = get_epoch_ckpt_name(args, ext, remove_no) + else: + remove_ckpt_name = get_step_ckpt_name(args, ext, remove_no) + + remove_ckpt_file = os.path.join(args.output_dir, remove_ckpt_name) + if os.path.exists(remove_ckpt_file): + logger.info(f"removing old checkpoint: {remove_ckpt_file}") + os.remove(remove_ckpt_file) + + else: + if on_epoch_end: + out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no)) + else: + out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step)) + + logger.info("") + logger.info(f"saving model: {out_dir}") + diffusers_saver(out_dir) + + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, out_dir, "/" + model_name) + + # remove older checkpoints + if remove_no is not None: + if on_epoch_end: + remove_out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, remove_no)) + else: + remove_out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, remove_no)) + + if os.path.exists(remove_out_dir): + logger.info(f"removing old model: {remove_out_dir}") + shutil.rmtree(remove_out_dir) + + if args.save_state: + if on_epoch_end: + save_and_remove_state_on_epoch_end(args, accelerator, epoch_no) + else: + save_and_remove_state_stepwise(args, accelerator, global_step) + + +def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no): + model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME) + + logger.info("") + logger.info(f"saving state at epoch {epoch_no}") + os.makedirs(args.output_dir, exist_ok=True) + + state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)) + accelerator.save_state(state_dir) + if args.save_state_to_huggingface: + logger.info("uploading state to huggingface.") + huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no)) + + last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs + if last_n_epochs is not None: + remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs + state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) + if os.path.exists(state_dir_old): + logger.info(f"removing old state: {state_dir_old}") + shutil.rmtree(state_dir_old) + + +def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no): + model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME) + + logger.info("") + logger.info(f"saving state at step {step_no}") + os.makedirs(args.output_dir, exist_ok=True) + + state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no)) + accelerator.save_state(state_dir) + if args.save_state_to_huggingface: + logger.info("uploading state to huggingface.") + huggingface_util.upload(args, state_dir, "/" + STEP_STATE_NAME.format(model_name, step_no)) + + last_n_steps = args.save_last_n_steps_state if args.save_last_n_steps_state else args.save_last_n_steps + if last_n_steps is not None: + # last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する + remove_step_no = step_no - last_n_steps - 1 + remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps) + + if remove_step_no > 0: + state_dir_old = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, remove_step_no)) + if os.path.exists(state_dir_old): + logger.info(f"removing old state: {state_dir_old}") + shutil.rmtree(state_dir_old) + + +def save_state_on_train_end(args: argparse.Namespace, accelerator): + model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME) + + logger.info("") + logger.info("saving last state.") + os.makedirs(args.output_dir, exist_ok=True) + + state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)) + accelerator.save_state(state_dir) + + if args.save_state_to_huggingface: + logger.info("uploading last state to huggingface.") + huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name)) + + +def save_sd_model_on_train_end( + args: argparse.Namespace, + src_path: str, + save_stable_diffusion_format: bool, + use_safetensors: bool, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + text_encoder, + unet, + vae, +): + def sd_saver(ckpt_file, epoch_no, global_step): + sai_metadata = get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True) + model_util.save_stable_diffusion_checkpoint( + args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, sai_metadata, save_dtype, vae + ) + + def diffusers_saver(out_dir): + model_util.save_diffusers_checkpoint( + args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors + ) + + save_sd_model_on_train_end_common( + args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver + ) + + +def save_sd_model_on_train_end_common( + args: argparse.Namespace, + save_stable_diffusion_format: bool, + use_safetensors: bool, + epoch: int, + global_step: int, + sd_saver, + diffusers_saver, +): + model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME) + + if save_stable_diffusion_format: + os.makedirs(args.output_dir, exist_ok=True) + + ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt") + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + logger.info(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") + sd_saver(ckpt_file, epoch, global_step) + + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True) + else: + out_dir = os.path.join(args.output_dir, model_name) + os.makedirs(out_dir, exist_ok=True) + + logger.info(f"save trained model as Diffusers to {out_dir}") + diffusers_saver(out_dir) + + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) + + +def get_timesteps(min_timestep, max_timestep, b_size, device): + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") + timesteps = timesteps.long().to(device) + return timesteps + + +def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + if args.noise_offset_random_strength: + noise_offset = torch.rand(1, device=latents.device) * args.noise_offset + else: + noise_offset = args.noise_offset + noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale) + if args.multires_noise_iterations: + noise = custom_train_functions.pyramid_noise_like( + noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount + ) + + # Sample a random timestep for each image + b_size = latents.shape[0] + min_timestep = 0 if args.min_timestep is None else args.min_timestep + max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep + + timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device) + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + if args.ip_noise_gamma_random_strength: + strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma + else: + strength = args.ip_noise_gamma + noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps) + else: + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + return noise, noisy_latents, timesteps + + +def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: + if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): + return None + + b_size = timesteps.shape[0] + if args.huber_schedule == "exponential": + alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps + result = torch.exp(-alpha * timesteps) * args.huber_scale + elif args.huber_schedule == "snr": + if not hasattr(noise_scheduler, "alphas_cumprod"): + raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) + sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 + result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c + result = result.to(timesteps.device) + elif args.huber_schedule == "constant": + result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device) + else: + raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") + + return result + + +def conditional_loss( + model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None +): + if loss_type == "l2": + loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) + elif loss_type == "l1": + loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction) + elif loss_type == "huber": + huber_c = huber_c.view(-1, 1, 1, 1) + loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) + if reduction == "mean": + loss = torch.mean(loss) + elif reduction == "sum": + loss = torch.sum(loss) + elif loss_type == "smooth_l1": + huber_c = huber_c.view(-1, 1, 1, 1) + loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) + if reduction == "mean": + loss = torch.mean(loss) + elif reduction == "sum": + loss = torch.sum(loss) + else: + raise NotImplementedError(f"Unsupported Loss Type: {loss_type}") + return loss + + +def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True): + names = [] + if including_unet: + names.append("unet") + names.append("text_encoder1") + names.append("text_encoder2") + + append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names) + + +def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names): + lrs = lr_scheduler.get_last_lr() + + for lr_index in range(len(lrs)): + name = names[lr_index] + logs["lr/" + name] = float(lrs[lr_index]) + + if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower(): + logs["lr/d*lr/" + name] = ( + lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"] + ) + + +# scheduler: +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = "scaled_linear" + + +def get_my_scheduler( + *, + sample_sampler: str, + v_parameterization: bool, +): + sched_init_args = {} + if sample_sampler == "ddim": + scheduler_cls = DDIMScheduler + elif sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + elif sample_sampler == "pndm": + scheduler_cls = PNDMScheduler + elif sample_sampler == "lms" or sample_sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + elif sample_sampler == "euler" or sample_sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + elif sample_sampler == "euler_a" or sample_sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteScheduler + elif sample_sampler == "dpmsolver" or sample_sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = sample_sampler + elif sample_sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + elif sample_sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + elif sample_sampler == "dpm_2" or sample_sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + else: + scheduler_cls = DDIMScheduler + + if v_parameterization: + sched_init_args["prediction_type"] = "v_prediction" + + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) + + # clip_sample=Trueにする + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + # logger.info("set clip_sample to True") + scheduler.config.clip_sample = True + + return scheduler + + +def sample_images(*args, **kwargs): + return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs) + + +def line_to_prompt_dict(line: str) -> dict: + # subset of gen_img_diffusers + prompt_args = line.split(" --") + prompt_dict = {} + prompt_dict["prompt"] = prompt_args[0] + + for parg in prompt_args: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict["width"] = int(m.group(1)) + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict["height"] = int(m.group(1)) + continue + + m = re.match(r"d (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict["seed"] = int(m.group(1)) + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1)))) + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + prompt_dict["scale"] = float(m.group(1)) + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + prompt_dict["negative_prompt"] = m.group(1) + continue + + m = re.match(r"ss (.+)", parg, re.IGNORECASE) + if m: + prompt_dict["sample_sampler"] = m.group(1) + continue + + m = re.match(r"cn (.+)", parg, re.IGNORECASE) + if m: + prompt_dict["controlnet_image"] = m.group(1) + continue + + except ValueError as ex: + logger.error(f"Exception in parsing / 解析エラー: {parg}") + logger.error(ex) + + return prompt_dict + + +def load_prompts(prompt_file: str) -> List[Dict]: + # read prompts + if prompt_file.endswith(".txt"): + with open(prompt_file, "r", encoding="utf-8") as f: + lines = f.readlines() + prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] + elif prompt_file.endswith(".toml"): + with open(prompt_file, "r", encoding="utf-8") as f: + data = toml.load(f) + prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] + elif prompt_file.endswith(".json"): + with open(prompt_file, "r", encoding="utf-8") as f: + prompts = json.load(f) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from library.train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + return prompts + + +def sample_images_common( + pipe_class, + accelerator: Accelerator, + args: argparse.Namespace, + epoch, + steps, + device, + vae, + tokenizer, + text_encoder, + unet, + validation_settings=None, + prompt_replacement=None, + controlnet=None, + +): + """ + StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した + TODO Use strategies here + """ + + logger.info("") + logger.info(f"generating sample images at step: {steps}") + + distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here + + org_vae_device = vae.device # CPUにいるはず + vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device + + # unwrap unet and text_encoder(s) + unet = accelerator.unwrap_model(unet) + if isinstance(text_encoder, (list, tuple)): + text_encoder = [accelerator.unwrap_model(te) for te in text_encoder] + else: + text_encoder = accelerator.unwrap_model(text_encoder) + + prompts = [] + for line in args.sample_prompts: + line = line.strip() + if len(line) > 0 and line[0] != "#": + prompts.append(line) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from .train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + default_scheduler = get_my_scheduler(sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization) + + pipeline = pipe_class( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=default_scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + clip_skip=args.clip_skip, + ) + pipeline.to(distributed_state.device) + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + + # save random state to restore later + rng_state = torch.get_rng_state() + cuda_rng_state = None + try: + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + except Exception: + pass + + with torch.no_grad(): + image_tensor_list = [] + for prompt_dict in prompts: + image_tensor = sample_image_inference( + accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet, validation_settings=validation_settings + ) + image_tensor_list.append(image_tensor) + + # clear pipeline and cache to reduce vram usage + del pipeline + + torch.set_rng_state(rng_state) + if torch.cuda.is_available() and cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + vae.to(org_vae_device) + + clean_memory_on_device(accelerator.device) + return torch.cat(image_tensor_list, dim=0) + + +def sample_image_inference( + accelerator: Accelerator, + args: argparse.Namespace, + pipeline, + save_dir, + prompt_dict, + epoch, + steps, + prompt_replacement=None, + controlnet=None, + validation_settings=None, +): + assert isinstance(prompt_dict, dict) + if validation_settings is not None: + sample_steps = validation_settings["steps"] + width = validation_settings["width"] + height = validation_settings["height"] + scale = validation_settings["guidance_scale"] + sampler_name = validation_settings["sampler"] + seed = validation_settings["seed"] + controlnet_image=None + else: + sample_steps = prompt_dict.get("sample_steps", 30) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 7.5) + seed = prompt_dict.get("seed") + controlnet_image = prompt_dict.get("controlnet_image") + sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) + prompt: str = prompt_dict.get("prompt", "") + negative_prompt = prompt_dict.get("negative_prompt") + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if seed is not None: + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + else: + # True random sample image generation + torch.seed() + if torch.cuda.is_available(): + torch.cuda.seed() + + scheduler = get_my_scheduler( + sample_sampler=sampler_name, + v_parameterization=args.v_parameterization, + ) + pipeline.scheduler = scheduler + + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + logger.info(f"prompt: {prompt}") + logger.info(f"negative_prompt: {negative_prompt}") + logger.info(f"height: {height}") + logger.info(f"width: {width}") + logger.info(f"sample_steps: {sample_steps}") + logger.info(f"scale: {scale}") + logger.info(f"sample_sampler: {sampler_name}") + if seed is not None: + logger.info(f"seed: {seed}") + with accelerator.autocast(): + latents = pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=sample_steps, + guidance_scale=scale, + negative_prompt=negative_prompt, + controlnet=controlnet, + controlnet_image=controlnet_image, + ) + + if torch.cuda.is_available(): + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() + + image, tensors = pipeline.latents_to_image(latents) + + # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list + # but adding 'enum' to the filename should be enough + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + i: int = prompt_dict["enum"] + img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png" + image[0].save(os.path.join(save_dir, img_filename)) + + + # send images to wandb if enabled + if "wandb" in [tracker.name for tracker in accelerator.trackers]: + wandb_tracker = accelerator.get_tracker("wandb") + + import wandb + + # not to commit images to avoid inconsistency between training and logging steps + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption + return tensors.permute(0, 2, 3, 1) + +# endregion + + +# region 前処理用 + + +class ImageLoadingDataset(torch.utils.data.Dataset): + def __init__(self, image_paths): + self.images = image_paths + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + img_path = self.images[idx] + + try: + image = Image.open(img_path).convert("RGB") + # convert to tensor temporarily so dataloader will accept it + tensor_pil = transforms.functional.pil_to_tensor(image) + except Exception as e: + logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}") + return None + + return (tensor_pil, img_path) + + +# endregion + + +# collate_fn用 epoch,stepはmultiprocessing.Value +class collator_class: + def __init__(self, epoch, step, dataset): + self.current_epoch = epoch + self.current_step = step + self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing + + def __call__(self, examples): + worker_info = torch.utils.data.get_worker_info() + # worker_info is None in the main process + if worker_info is not None: + dataset = worker_info.dataset + else: + dataset = self.dataset + + # set epoch and step + dataset.set_current_epoch(self.current_epoch.value) + dataset.set_current_step(self.current_step.value) + return examples[0] + + +class LossRecorder: + def __init__(self): + self.loss_list: List[float] = [] + self.global_loss_list: List[float] = [] + self.loss_total: float = 0.0 + + def add(self, *, epoch: int, step: int, global_step: int, loss: float) -> None: + if epoch == 0: + self.loss_list.append(loss) + self.global_loss_list.append(loss) + else: + while len(self.loss_list) <= step: + self.loss_list.append(0.0) + self.loss_total -= self.loss_list[step] + self.loss_list[step] = loss + self.global_loss_list.append(loss) + self.loss_total += loss + + @property + def moving_average(self) -> float: + return self.loss_total / len(self.loss_list) diff --git a/library/utils.py b/library/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cbbf3f51bb183d92606c02dda05ce619f23779a8 --- /dev/null +++ b/library/utils.py @@ -0,0 +1,596 @@ +import logging +import sys +import threading +import torch +from torchvision import transforms +from typing import * +import json +import struct +from diffusers import EulerAncestralDiscreteScheduler +import diffusers.schedulers.scheduling_euler_ancestral_discrete +from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput +import cv2 +from PIL import Image +import numpy as np +from safetensors.torch import load_file + +def load_safetensors( + path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32 +): + if disable_mmap: + # return safetensors.torch.load(open(path, "rb").read()) + # use experimental loader + #logger.info(f"Loading without mmap (experimental)") + state_dict = {} + with MemoryEfficientSafeOpen(path) as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key).to(device, dtype=dtype) + return state_dict + else: + try: + return load_file(path, device=device) + except: + return load_file(path) # prevent device invalid Error + +def fire_in_thread(f, *args, **kwargs): + threading.Thread(target=f, args=args, kwargs=kwargs).start() + + +def add_logging_arguments(parser): + parser.add_argument( + "--console_log_level", + type=str, + default=None, + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO", + ) + parser.add_argument( + "--console_log_file", + type=str, + default=None, + help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する", + ) + parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力") + + +def setup_logging(args=None, log_level=None, reset=False): + if logging.root.handlers: + if reset: + # remove all handlers + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + else: + return + + # log_level can be set by the caller or by the args, the caller has priority. If not set, use INFO + if log_level is None and args is not None: + log_level = args.console_log_level + if log_level is None: + log_level = "INFO" + log_level = getattr(logging, log_level) + + msg_init = None + if args is not None and args.console_log_file: + handler = logging.FileHandler(args.console_log_file, mode="w") + else: + handler = None + if not args or not args.console_log_simple: + try: + from rich.logging import RichHandler + from rich.console import Console + from rich.logging import RichHandler + + handler = RichHandler(console=Console(stderr=True)) + except ImportError: + # print("rich is not installed, using basic logging") + msg_init = "rich is not installed, using basic logging" + + if handler is None: + handler = logging.StreamHandler(sys.stdout) # same as print + handler.propagate = False + + formatter = logging.Formatter( + fmt="%(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + handler.setFormatter(formatter) + logging.root.setLevel(log_level) + logging.root.addHandler(handler) + + if msg_init is not None: + logger = logging.getLogger(__name__) + logger.info(msg_init) + + +def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: + """ + Convert a string to a torch.dtype + + Args: + s: string representation of the dtype + default_dtype: default dtype to return if s is None + + Returns: + torch.dtype: the corresponding torch.dtype + + Raises: + ValueError: if the dtype is not supported + + Examples: + >>> str_to_dtype("float32") + torch.float32 + >>> str_to_dtype("fp32") + torch.float32 + >>> str_to_dtype("float16") + torch.float16 + >>> str_to_dtype("fp16") + torch.float16 + >>> str_to_dtype("bfloat16") + torch.bfloat16 + >>> str_to_dtype("bf16") + torch.bfloat16 + >>> str_to_dtype("fp8") + torch.float8_e4m3fn + >>> str_to_dtype("fp8_e4m3fn") + torch.float8_e4m3fn + >>> str_to_dtype("fp8_e4m3fnuz") + torch.float8_e4m3fnuz + >>> str_to_dtype("fp8_e5m2") + torch.float8_e5m2 + >>> str_to_dtype("fp8_e5m2fnuz") + torch.float8_e5m2fnuz + """ + if s is None: + return default_dtype + if s in ["bf16", "bfloat16"]: + return torch.bfloat16 + elif s in ["fp16", "float16"]: + return torch.float16 + elif s in ["fp32", "float32", "float"]: + return torch.float32 + elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]: + return torch.float8_e4m3fn + elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]: + return torch.float8_e4m3fnuz + elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]: + return torch.float8_e5m2 + elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]: + return torch.float8_e5m2fnuz + elif s in ["fp8", "float8"]: + return torch.float8_e4m3fn # default fp8 + else: + raise ValueError(f"Unsupported dtype: {s}") + + +def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None): + """ + memory efficient save file + """ + + _TYPES = { + torch.float64: "F64", + torch.float32: "F32", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int64: "I64", + torch.int32: "I32", + torch.int16: "I16", + torch.int8: "I8", + torch.uint8: "U8", + torch.bool: "BOOL", + getattr(torch, "float8_e5m2", None): "F8_E5M2", + getattr(torch, "float8_e4m3fn", None): "F8_E4M3", + } + _ALIGN = 256 + + def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]: + validated = {} + for key, value in metadata.items(): + if not isinstance(key, str): + raise ValueError(f"Metadata key must be a string, got {type(key)}") + if not isinstance(value, str): + print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.") + validated[key] = str(value) + else: + validated[key] = value + return validated + + print(f"Using memory efficient save file: {filename}") + + header = {} + offset = 0 + if metadata: + header["__metadata__"] = validate_metadata(metadata) + for k, v in tensors.items(): + if v.numel() == 0: # empty tensor + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]} + else: + size = v.numel() * v.element_size() + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]} + offset += size + + hjson = json.dumps(header).encode("utf-8") + hjson += b" " * (-(len(hjson) + 8) % _ALIGN) + + with open(filename, "wb") as f: + f.write(struct.pack(" str: + return ( + f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, " + + f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, " + + f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, " + + f"unsharp_target_x={self.unsharp_target_x})" + ) + + def apply_unshark_mask(self, x: torch.Tensor): + if self.gaussian_blur_ksize is None: + return x + blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma) + # mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength) + mask = (x - blurred) * self.gaussian_blur_strength + sharpened = x + mask + return sharpened + + def interpolate(self, x: torch.Tensor, resized_size, unsharp=True): + org_dtype = x.dtype + if org_dtype == torch.bfloat16: + x = x.float() + + x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if unsharp and self.gaussian_blur_ksize: + x = self.apply_unshark_mask(x) + + return x + + +class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.resized_size = None + self.gradual_latent = None + + def set_gradual_latent_params(self, size, gradual_latent: GradualLatent): + self.resized_size = size + self.gradual_latent = gradual_latent + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + + """ + + if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + # logger.warning( + print( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`") + + sigma_from = self.sigmas[self.step_index] + sigma_to = self.sigmas[self.step_index + 1] + sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + + dt = sigma_down - sigma + + device = model_output.device + if self.resized_size is None: + prev_sample = sample + derivative * dt + + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( + model_output.shape, dtype=model_output.dtype, device=device, generator=generator + ) + s_noise = 1.0 + else: + print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape) + s_noise = self.gradual_latent.s_noise + + if self.gradual_latent.unsharp_target_x: + prev_sample = sample + derivative * dt + prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size) + else: + sample = self.gradual_latent.interpolate(sample, self.resized_size) + derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False) + prev_sample = sample + derivative * dt + + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( + (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]), + dtype=model_output.dtype, + device=device, + generator=generator, + ) + + prev_sample = prev_sample + noise * sigma_up * s_noise + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + +# endregion + +def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None): + """ + memory efficient save file + """ + + _TYPES = { + torch.float64: "F64", + torch.float32: "F32", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int64: "I64", + torch.int32: "I32", + torch.int16: "I16", + torch.int8: "I8", + torch.uint8: "U8", + torch.bool: "BOOL", + getattr(torch, "float8_e5m2", None): "F8_E5M2", + getattr(torch, "float8_e4m3fn", None): "F8_E4M3", + } + _ALIGN = 256 + + def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]: + validated = {} + for key, value in metadata.items(): + if not isinstance(key, str): + raise ValueError(f"Metadata key must be a string, got {type(key)}") + if not isinstance(value, str): + print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.") + validated[key] = str(value) + else: + validated[key] = value + return validated + + print(f"Using memory efficient save file: {filename}") + + header = {} + offset = 0 + if metadata: + header["__metadata__"] = validate_metadata(metadata) + for k, v in tensors.items(): + if v.numel() == 0: # empty tensor + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]} + else: + size = v.numel() * v.element_size() + header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]} + offset += size + + hjson = json.dumps(header).encode("utf-8") + hjson += b" " * (-(len(hjson) + 8) % _ALIGN) + + with open(filename, "wb") as f: + f.write(struct.pack(" b i ...", bi, inp) + inp = inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3) + + if rescale is not None: + inp = inp * rescale + + return inp - org + + +def bypass_forward_diff(org_out, *weights, constraint=None, need_transpose=False): + """### boft_bypass_forward_diff + + Args: + x (torch.Tensor): the input tensor for original model + org_out (torch.Tensor): the output tensor from original model + weights (tuple[torch.Tensor]): (oft_blocks[, rescale_weight]) + constraint (float, optional): constraint for oft + need_transpose (bool, optional): + whether to transpose the input and output, + set to `True` if the original model have "dim" not in the last axis. + For example: Convolution layers + + Returns: + torch.Tensor: output tensor + """ + oft_blocks, rescale = weights + m, num, b, _ = oft_blocks.shape + r_b = b // 2 + I = torch.eye(b, device=oft_blocks.device) + r = get_r(oft_blocks, I, constraint) + inp = org = org_out.to(dtype=r.dtype) + if need_transpose: + inp = org = inp.transpose(1, -1) + + for i in range(m): + bi = r[i] # b_num, b_size, b_size + g = 2 + k = 2**i * r_b + # ... (c g k) ->... (c k g) + # ... (d b) -> ... d b + inp = ( + inp.unflatten(-1, (-1, g, k)) + .transpose(-2, -1) + .flatten(-3) + .unflatten(-1, (-1, b)) + ) + inp = torch.einsum("b i j, ... b j -> ... b i", bi, inp) + # ... d b -> ... (d b) + # ... (c k g) -> ... (c g k) + inp = inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3) + + if rescale is not None: + inp = inp * rescale.transpose(0, -1) + + inp = inp - org + if need_transpose: + inp = inp.transpose(1, -1) + return inp diff --git a/lycoris/functional/diag_oft.py b/lycoris/functional/diag_oft.py new file mode 100644 index 0000000000000000000000000000000000000000..233a6096eaec8d6d87cccd8c774aa84fd62daf8b --- /dev/null +++ b/lycoris/functional/diag_oft.py @@ -0,0 +1,112 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .general import factorization, FUNC_LIST + + +def get_r(oft_blocks, I=None, constraint=0): + if I is None: + I = torch.eye(oft_blocks.shape[-1], device=oft_blocks.device) + if I.ndim < oft_blocks.ndim: + for _ in range(oft_blocks.ndim - I.ndim): + I = I.unsqueeze(0) + # for Q = -Q^T + q = oft_blocks - oft_blocks.transpose(-1, -2) + normed_q = q + if constraint is not None and constraint > 0: + q_norm = torch.norm(q) + 1e-8 + if q_norm > constraint: + normed_q = q * constraint / q_norm + # use float() to prevent unsupported type + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r + + +def weight_gen(org_weight, max_block_size=-1, rescale=False): + """### weight_gen + + Args: + org_weight (torch.Tensor): the weight tensor + max_block_size (int): max block size + rescale (bool, optional): whether to rescale the weight. Defaults to False. + + Returns: + torch.Tensor: oft_blocks[, rescale_weight] + """ + out_dim, *rest = org_weight.shape + block_size, block_num = factorization(out_dim, max_block_size) + oft_blocks = torch.zeros(block_num, block_size, block_size) + if rescale: + return oft_blocks, torch.ones(out_dim, *[1] * len(rest)) + else: + return oft_blocks, None + + +def diff_weight(org_weight, *weights, constraint=None): + """### diff_weight + + Args: + org_weight (torch.Tensor): the weight tensor of original model + weights (tuple[torch.Tensor]): (oft_blocks[, rescale_weight]) + constraint (float, optional): constraint for oft + + Returns: + torch.Tensor: ΔW + """ + oft_blocks, rescale = weights + I = torch.eye(oft_blocks.shape[1], device=oft_blocks.device) + r = get_r(oft_blocks, I, constraint) + + block_num, block_size, _ = oft_blocks.shape + _, *shape = org_weight.shape + org_weight = org_weight.to(dtype=r.dtype) + org_weight = org_weight.view(block_num, block_size, *shape) + # Init R=0, so add I on it to ensure the output of step0 is original model output + weight = torch.einsum( + "k n m, k n ... -> k m ...", + r - I, + org_weight, + ).view(-1, *shape) + if rescale is not None: + weight = rescale * weight + weight = weight + (rescale - 1) * org_weight + return weight + + +def bypass_forward_diff(x, org_out, *weights, constraint=None, need_transpose=False): + """### bypass_forward_diff + + Args: + x (torch.Tensor): the input tensor for original model + org_out (torch.Tensor): the output tensor from original model + weights (tuple[torch.Tensor]): (oft_blocks[, rescale_weight]) + constraint (float, optional): constraint for oft + need_transpose (bool, optional): + whether to transpose the input and output, + set to `True` if the original model have "dim" not in the last axis. + For example: Convolution layers + + Returns: + torch.Tensor: output tensor + """ + oft_blocks, rescale = weights + block_num, block_size, _ = oft_blocks.shape + I = torch.eye(block_size, device=oft_blocks.device) + r = get_r(oft_blocks, I, constraint) + if need_transpose: + org_out = org_out.transpose(1, -1) + org_out = org_out.to(dtype=r.dtype) + *shape, _ = org_out.shape + oft_out = torch.einsum( + "k n m, ... k n -> ... k m", r - I, org_out.view(*shape, block_num, block_size) + ) + out = oft_out.view(*shape, -1) + if rescale is not None: + out = rescale.transpose(-1, 0) * out + out = out + (rescale - 1).transpose(-1, 0) * org_out + if need_transpose: + out = out.transpose(1, -1) + return out diff --git a/lycoris/functional/general.py b/lycoris/functional/general.py new file mode 100644 index 0000000000000000000000000000000000000000..dc0f8a04dfa6d692b518b5f1a210a25bd1098ee0 --- /dev/null +++ b/lycoris/functional/general.py @@ -0,0 +1,108 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + + +def rebuild_tucker(t, wa, wb): + rebuild2 = torch.einsum("i j ..., i p, j r -> p r ...", t, wa, wb) + return rebuild2 + + +def factorization(dimension: int, factor: int = -1) -> tuple[int, int]: + """ + return a tuple of two value of input dimension decomposed by the number closest to factor + second value is higher or equal than first value. + + In LoRA with Kroneckor Product, first value is a value for weight scale. + second value is a value for weight. + + Because of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. + + examples) + factor + -1 2 4 8 16 ... + 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 + 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 + 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 + 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 + 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 + 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 + """ + + if factor > 0 and (dimension % factor) == 0: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m < n: + new_m = m + 1 + while dimension % new_m != 0: + new_m += 1 + new_n = dimension // new_m + if new_m + new_n > length or new_m > factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n + + +def power2factorization(dimension: int, factor: int = -1) -> tuple[int, int]: + """ + m = 2k + n = 2**p + m*n = dim + """ + if factor == -1: + factor = dimension + + # Find the first solution and check if it is even doable + m = n = 0 + while m <= factor: + m += 2 + while dimension % m != 0 and m < dimension: + m += 2 + if m > factor: + break + if sum(int(i) for i in f"{dimension//m:b}") == 1: + n = dimension // m + + if n == 0: + return None, n + return dimension // n, n + + +def tucker_weight_from_conv(up, down, mid): + up = up.reshape(up.size(0), up.size(1)) + down = down.reshape(down.size(0), down.size(1)) + return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down) + + +def tucker_weight(wa, wb, t): + temp = torch.einsum("i j ..., j r -> i r ...", t, wb) + return torch.einsum("i j ..., i r -> r j ...", temp, wa) + + +def apply_dora_scale(org_weight, rebuild, dora_scale, scale): + dora_norm_dims = org_weight.dim() - 1 + weight = org_weight + rebuild + weight = weight.to(dora_scale.dtype) + weight_norm = ( + weight.transpose(0, 1) + .reshape(weight.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight.shape[1], *[1] * dora_norm_dims) + .transpose(0, 1) + ) + merged_scale1 = weight / weight_norm * dora_scale + diff_weight = merged_scale1 - org_weight + return org_weight + diff_weight * scale diff --git a/lycoris/functional/locon.py b/lycoris/functional/locon.py new file mode 100644 index 0000000000000000000000000000000000000000..756bbd82d7dff35d2ebc150db7a666ec423673df --- /dev/null +++ b/lycoris/functional/locon.py @@ -0,0 +1,85 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .general import rebuild_tucker, FUNC_LIST + + +def weight_gen(org_weight, rank, tucker=True): + """### weight_gen + + Args: + org_weight (torch.Tensor): the weight tensor + rank (int): low rank + + Returns: + torch.Tensor: down, up[, mid] + """ + out_dim, in_dim, *k = org_weight.shape + if k and tucker: + down = torch.empty(rank, in_dim, *(1 for _ in k)) + up = torch.empty(out_dim, rank, *(1 for _ in k)) + mid = torch.empty(rank, rank, *k) + nn.init.kaiming_uniform_(down, a=math.sqrt(5)) + nn.init.constant_(up, 0) + nn.init.kaiming_uniform_(mid, a=math.sqrt(5)) + return down, up, mid + else: + down = torch.empty(rank, in_dim) + up = torch.empty(out_dim, rank) + nn.init.kaiming_uniform_(down, a=math.sqrt(5)) + nn.init.constant_(up, 0) + return down, up, None + + +def diff_weight(*weights: tuple[torch.Tensor], gamma=1.0): + """### diff_weight + + Get ΔW = BA, where BA is low rank decomposition + + Args: + weights (tuple[torch.Tensor]): (down, up[, mid]) + gamma (float, optional): scale factor, normally alpha/rank here + + Returns: + torch.Tensor: ΔW + """ + d, u, m = weights + R, I, *k = d.shape + O, R, *_ = u.shape + u = u * gamma + + if m is None: + result = u.reshape(-1, u.size(1)) @ d.reshape(d.size(0), -1) + else: + R, R, *k = m.shape + u = u.reshape(u.size(0), -1).transpose(0, 1) + d = d.reshape(d.size(0), -1) + result = rebuild_tucker(m, u, d) + return result.reshape(O, I, *k) + + +def bypass_forward_diff(x, org_out, *weights, gamma=1.0, extra_args={}): + """### bypass_forward_diff + + Args: + x (torch.Tensor): input tensor + weights (tuple[torch.Tensor]): (down, up[, mid]) + gamma (float, optional): scale factor, normally alpha/rank here + extra_args (dict, optional): extra args for forward func, \ + e.g. padding, stride for Conv1/2/3d + + Returns: + torch.Tensor: output tensor + """ + d, u, m = weights + if m is not None: + down = FUNC_LIST[d.dim()](x, d) + mid = FUNC_LIST[d.dim()](down, m, **extra_args) + up = FUNC_LIST[d.dim()](mid, u) + else: + down = FUNC_LIST[d.dim()](x, d, **extra_args) + up = FUNC_LIST[d.dim()](down, u) + return up * gamma diff --git a/lycoris/functional/loha.py b/lycoris/functional/loha.py new file mode 100644 index 0000000000000000000000000000000000000000..042e56bdeedb84592949a112c109b777d5f8f27c --- /dev/null +++ b/lycoris/functional/loha.py @@ -0,0 +1,165 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .general import FUNC_LIST + + +class HadaWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, w1d, w1u, w2d, w2u, scale=torch.tensor(1)): + ctx.save_for_backward(w1d, w1u, w2d, w2u, scale) + diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale + return diff_weight + + @staticmethod + def backward(ctx, grad_out): + (w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors + grad_out = grad_out * scale + temp = grad_out * (w2u @ w2d) + grad_w1u = temp @ w1d.T + grad_w1d = w1u.T @ temp + + temp = grad_out * (w1u @ w1d) + grad_w2u = temp @ w2d.T + grad_w2d = w2u.T @ temp + + del temp + return grad_w1d, grad_w1u, grad_w2d, grad_w2u, None + + +class HadaWeightTucker(torch.autograd.Function): + @staticmethod + def forward(ctx, t1, w1d, w1u, t2, w2d, w2u, scale=torch.tensor(1)): + ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale) + + rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u) + rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u) + + return rebuild1 * rebuild2 * scale + + @staticmethod + def backward(ctx, grad_out): + (t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors + grad_out = grad_out * scale + + temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d) + rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u) + + grad_w = rebuild * grad_out + del rebuild + + grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w) + grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T) + del grad_w, temp + + grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp) + grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T) + del grad_temp + + temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d) + rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u) + + grad_w = rebuild * grad_out + del rebuild + + grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w) + grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T) + del grad_w, temp + + grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp) + grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T) + del grad_temp + return grad_t1, grad_w1d, grad_w1u, grad_t2, grad_w2d, grad_w2u, None + + +def make_weight(w1d, w1u, w2d, w2u, scale): + return HadaWeight.apply(w1d, w1u, w2d, w2u, scale) + + +def make_weight_tucker(t1, w1d, w1u, t2, w2d, w2u, scale): + return HadaWeightTucker.apply(t1, w1d, w1u, t2, w2d, w2u, scale) + + +def weight_gen(org_weight, rank, tucker=True): + """### weight_gen + + Args: + org_weight (torch.Tensor): the weight tensor + rank (int): low rank + + Returns: + torch.Tensor: w1d, w2d, w1u, w2u[, t1, t2] + """ + out_dim, in_dim, *k = org_weight.shape + if k and tucker: + w1d = torch.empty(rank, in_dim) + w1u = torch.empty(rank, out_dim) + t1 = torch.empty(rank, rank, *k) + w2d = torch.empty(rank, in_dim) + w2u = torch.empty(rank, out_dim) + t2 = torch.empty(rank, rank, *k) + nn.init.normal_(t1, std=0.1) + nn.init.normal_(t2, std=0.1) + else: + w1d = torch.empty(rank, in_dim) + w1u = torch.empty(out_dim, rank) + w2d = torch.empty(rank, in_dim) + w2u = torch.empty(out_dim, rank) + t1 = t2 = None + nn.init.normal_(w1d, std=1) + nn.init.constant_(w1u, 0) + nn.init.normal_(w2d, std=1) + nn.init.normal_(w2u, std=0.1) + return w1d, w1u, w2d, w2u, t1, t2 + + +def diff_weight(*weights, gamma=1.0): + """### diff_weight + + Get ΔW = BA, where BA is low rank decomposition + + Args: + wegihts (tuple[torch.Tensor]): (w1d, w2d, w1u, w2u[, t1, t2]) + gamma (float, optional): scale factor, normally alpha/rank here + + Returns: + torch.Tensor: ΔW + """ + w1d, w1u, w2d, w2u, t1, t2 = weights + if t1 is not None and t2 is not None: + R, I = w1d.shape + R, O = w1u.shape + R, R, *k = t1.shape + result = make_weight_tucker(t1, w1d, w1u, t2, w2d, w2u, gamma) + else: + R, I, *k = w1d.shape + O, R, *_ = w1u.shape + w1d = w1d.reshape(w1d.size(0), -1) + w1u = w1u.reshape(-1, w1u.size(1)) + w2d = w2d.reshape(w2d.size(0), -1) + w2u = w2u.reshape(-1, w2u.size(1)) + result = make_weight(w1d, w1u, w2d, w2u, gamma) + + result = result.reshape(O, I, *k) + return result + + +def bypass_forward_diff(x, org_out, *weights, gamma=1.0, extra_args={}): + """### bypass_forward_diff + + Args: + x (torch.Tensor): input tensor + weights (tuple[torch.Tensor]): (w1d, w2d, w1u, w2u[, t1, t2]) + gamma (float, optional): scale factor, normally alpha/rank here + extra_args (dict, optional): extra args for forward func, \ + e.g. padding, stride for Conv1/2/3d + + Returns: + torch.Tensor: output tensor + """ + w1d, w1u, w2d, w2u, t1, t2 = weights + diff_w = diff_weight(w1d, w1u, w2d, w2u, t1, t2, gamma) + return FUNC_LIST[w1d.dim() if t1 is None else t1.dim()](x, diff_w, **extra_args) diff --git a/lycoris/functional/lokr.py b/lycoris/functional/lokr.py new file mode 100644 index 0000000000000000000000000000000000000000..75720ed61a9aa5bf0a45008fc9bca71a7373fd2b --- /dev/null +++ b/lycoris/functional/lokr.py @@ -0,0 +1,247 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .general import rebuild_tucker, FUNC_LIST +from .general import factorization + + +def make_kron(w1, w2, scale): + for _ in range(w2.dim() - w1.dim()): + w1 = w1.unsqueeze(-1) + w2 = w2.contiguous() + rebuild = torch.kron(w1, w2) + + if scale != 1: + rebuild = rebuild * scale + + return rebuild + + +def weight_gen( + org_weight, + rank, + tucker=True, + factor=-1, + decompose_both=False, + full_matrix=False, + unbalanced_factorization=False, +): + """### weight_gen + + Args: + org_weight (torch.Tensor): the weight tensor + rank (int): low rank + + Returns: + torch.Tensor | None: w1, w1a, w1b, w2, w2a, w2b, t2 + """ + out_dim, in_dim, *k = org_weight.shape + w1 = w1a = w1b = None + w2 = w2a = w2b = None + t2 = None + use_w1 = use_w2 = False + + if k: + k_size = k + shape = (out_dim, in_dim, *k_size) + + in_m, in_n = factorization(in_dim, factor) + out_l, out_k = factorization(out_dim, factor) + if unbalanced_factorization: + out_l, out_k = out_k, out_l + shape = ((out_l, out_k), (in_m, in_n), *k_size) # ((a, b), (c, d), *k_size) + tucker = tucker and any(i != 1 for i in k_size) + if ( + decompose_both + and rank < max(shape[0][0], shape[1][0]) / 2 + and not full_matrix + ): + w1a = torch.empty(shape[0][0], rank) + w1b = torch.empty(rank, shape[1][0]) + else: + use_w1 = True + w1 = torch.empty(shape[0][0], shape[1][0]) # a*c, 1-mode + + if rank >= max(shape[0][1], shape[1][1]) / 2 or full_matrix: + use_w2 = True + w2 = torch.empty(shape[0][1], shape[1][1], *k_size) + elif tucker: + t2 = torch.empty(rank, rank, *shape[2:]) + w2a = torch.empty(rank, shape[0][1]) # b, 1-mode + w2b = torch.empty(rank, shape[1][1]) # d, 2-mode + else: # Conv2d not tucker + # bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2] + w2a = torch.empty(shape[0][1], rank) + w2b = torch.empty(rank, shape[1][1], *shape[2:]) + # w1 ⊗ (w2a x w2b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2) + else: # Linear + shape = (out_dim, in_dim) + + in_m, in_n = factorization(in_dim, factor) + out_l, out_k = factorization(out_dim, factor) + if unbalanced_factorization: + out_l, out_k = out_k, out_l + shape = ( + (out_l, out_k), + (in_m, in_n), + ) # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d + # smaller part. weight scale + if decompose_both and rank < max(shape[0][0], shape[1][0]) / 2: + w1a = torch.empty(shape[0][0], rank) + w1b = torch.empty(rank, shape[1][0]) + else: + use_w1 = True + w1 = torch.empty(shape[0][0], shape[1][0]) # a*c, 1-mode + if rank < max(shape[0][1], shape[1][1]) / 2: + # bigger part. weight and LoRA. [b, dim] x [dim, d] + w2a = torch.empty(shape[0][1], rank) + w2b = torch.empty(rank, shape[1][1]) + # w1 ⊗ (w2a x w2b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd) + else: + use_w2 = True + w2 = torch.empty(shape[0][1], shape[1][1]) + + if use_w2: + torch.nn.init.constant_(w2, 1) + else: + if tucker: + torch.nn.init.kaiming_uniform_(t2, a=math.sqrt(5)) + torch.nn.init.kaiming_uniform_(w2a, a=math.sqrt(5)) + torch.nn.init.constant_(w2b, 1) + + if use_w1: + torch.nn.init.kaiming_uniform_(w1, a=math.sqrt(5)) + else: + torch.nn.init.kaiming_uniform_(w1a, a=math.sqrt(5)) + torch.nn.init.kaiming_uniform_(w1b, a=math.sqrt(5)) + + return w1, w1a, w1b, w2, w2a, w2b, t2 + + +def diff_weight(*weights, gamma=1.0): + """### diff_weight + + Args: + weights (tuple[torch.Tensor]): (w1, w1a, w1b, w2, w2a, w2b, t) + gamma (float, optional): scale factor, normally alpha/rank here + + Returns: + torch.Tensor: ΔW + """ + w1, w1a, w1b, w2, w2a, w2b, t = weights + if w1a is not None: + rank = w1a.shape[1] + elif w2a is not None: + rank = w2a.shape[1] + else: + rank = gamma + scale = gamma / rank + if w1 is None: + w1 = w1a @ w1b + if w2 is None: + if t is None: + r, o, *k = w2b.shape + w2 = w2a @ w2b.view(r, -1) + w2 = w2.view(-1, o, *k) + else: + w2 = rebuild_tucker(t, w2a, w2b) + return make_kron(w1, w2, scale) + + +def bypass_forward_diff(h, org_out, *weights, gamma=1.0, extra_args={}): + """### bypass_forward_diff + + Args: + weights (tuple[torch.Tensor]): (w1, w1a, w1b, w2, w2a, w2b, t) + gamma (float, optional): scale factor, normally alpha/rank here + extra_args (dict, optional): extra args for forward func, \ + e.g. padding, stride for Conv1/2/3d + + Returns: + torch.Tensor: output tensor + """ + w1, w1a, w1b, w2, w2a, w2b, t = weights + use_w1 = w1 is not None + use_w2 = w2 is not None + tucker = t is not None + dim = t.dim() if tucker else w2.dim() if w2 is not None else w2b.dim() + rank = w1b.size(0) if not use_w1 else w2b.size(0) if not use_w2 else gamma + scale = gamma / rank + is_conv = dim > 2 + op = FUNC_LIST[dim] + + if is_conv: + kw_dict = extra_args + else: + kw_dict = {} + + if use_w2: + ba = w2 + else: + a = w2b + b = w2a + + if t is not None: + a = a.view(*a.shape, *[1] * (dim - 2)) + b = b.view(*b.shape, *[1] * (dim - 2)) + elif is_conv: + b = b.view(*b.shape, *[1] * (dim - 2)) + + if use_w1: + c = w1 + else: + c = w1a @ w1b + uq = c.size(1) + + if is_conv: + # (b, uq), vq, ... + B, _, *rest = h.shape + h_in_group = h.reshape(B * uq, -1, *rest) + else: + # b, ..., uq, vq + h_in_group = h.reshape(*h.shape[:-1], uq, -1) + + if use_w2: + hb = op(h_in_group, ba, **kw_dict) + else: + if is_conv: + if tucker: + ha = op(h_in_group, a) + ht = op(ha, t, **kw_dict) + hb = op(ht, b) + else: + ha = op(h_in_group, a, **kw_dict) + hb = op(ha, b) + else: + ha = op(h_in_group, a, **kw_dict) + hb = op(ha, b) + + if is_conv: + # (b, uq), vp, ..., f + # -> b, uq, vp, ..., f + # -> b, f, vp, ..., uq + hb = hb.view(B, -1, *hb.shape[1:]) + h_cross_group = hb.transpose(1, -1) + else: + # b, ..., uq, vq + # -> b, ..., vq, uq + h_cross_group = hb.transpose(-1, -2) + + hc = F.linear(h_cross_group, c) + if is_conv: + # b, f, vp, ..., up + # -> b, up, vp, ... ,f + # -> b, c, ..., f + hc = hc.transpose(1, -1) + h = hc.reshape(B, -1, *hc.shape[3:]) + else: + # b, ..., vp, up + # -> b, ..., up, vp + # -> b, ..., c + hc = hc.transpose(-1, -2) + h = hc.reshape(*hc.shape[:-2], -1) + + return h * scale diff --git a/lycoris/kohya.py b/lycoris/kohya.py new file mode 100644 index 0000000000000000000000000000000000000000..3f31872b0fd6fe944df7fd30718618a4675e0290 --- /dev/null +++ b/lycoris/kohya.py @@ -0,0 +1,676 @@ +import os +import fnmatch +import re +import logging + +from typing import Any, List + +import torch + +from .utils import precalculate_safetensors_hashes +from .wrapper import LycorisNetwork, network_module_dict, deprecated_arg_dict +from .modules.locon import LoConModule +from .modules.loha import LohaModule +from .modules.ia3 import IA3Module +from .modules.lokr import LokrModule +from .modules.dylora import DyLoraModule +from .modules.glora import GLoRAModule +from .modules.norms import NormModule +from .modules.full import FullModule +from .modules.diag_oft import DiagOFTModule +from .modules.boft import ButterflyOFTModule +from .modules import make_module, get_module + +from .config import PRESET +from .utils.preset import read_preset +from .utils import str_bool +from .logging import logger + + +def create_network( + multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs +): + for key, value in list(kwargs.items()): + if key in deprecated_arg_dict: + logger.warning( + f"{key} is deprecated. Please use {deprecated_arg_dict[key]} instead.", + stacklevel=2, + ) + kwargs[deprecated_arg_dict[key]] = value + if network_dim is None: + network_dim = 4 # default + conv_dim = int(kwargs.get("conv_dim", network_dim) or network_dim) + conv_alpha = float(kwargs.get("conv_alpha", network_alpha) or network_alpha) + dropout = float(kwargs.get("dropout", 0.0) or 0.0) + rank_dropout = float(kwargs.get("rank_dropout", 0.0) or 0.0) + module_dropout = float(kwargs.get("module_dropout", 0.0) or 0.0) + algo = (kwargs.get("algo", "lora") or "lora").lower() + use_tucker = str_bool( + not kwargs.get("disable_conv_cp", True) + or kwargs.get("use_conv_cp", False) + or kwargs.get("use_cp", False) + or kwargs.get("use_tucker", False) + ) + use_scalar = str_bool(kwargs.get("use_scalar", False)) + block_size = int(kwargs.get("block_size", None) or 4) + train_norm = str_bool(kwargs.get("train_norm", False)) + constraint = float(kwargs.get("constraint", None) or 0) + rescaled = str_bool(kwargs.get("rescaled", False)) + weight_decompose = str_bool(kwargs.get("dora_wd", False)) + wd_on_output = str_bool(kwargs.get("wd_on_output", False)) + full_matrix = str_bool(kwargs.get("full_matrix", False)) + bypass_mode = str_bool(kwargs.get("bypass_mode", None)) + rs_lora = str_bool(kwargs.get("rs_lora", False)) + unbalanced_factorization = str_bool(kwargs.get("unbalanced_factorization", False)) + train_t5xxl = str_bool(kwargs.get("train_t5xxl", False)) + + if unbalanced_factorization: + logger.info("Unbalanced factorization for LoKr is enabled") + + if bypass_mode: + logger.info("Bypass mode is enabled") + + if weight_decompose: + logger.info("Weight decomposition is enabled") + + if full_matrix: + logger.info("Full matrix mode for LoKr is enabled") + + preset_str = kwargs.get("preset", "full") + if preset_str not in PRESET: + preset = read_preset(preset_str) + else: + preset = PRESET[preset_str] + assert preset is not None + LycorisNetworkKohya.apply_preset(preset) + + logger.info(f"Using rank adaptation algo: {algo}") + + if algo == "ia3" and preset_str != "ia3": + logger.warning("It is recommended to use preset ia3 for IA^3 algorithm") + + network = LycorisNetworkKohya( + text_encoder, + unet, + multiplier=multiplier, + lora_dim=network_dim, + conv_lora_dim=conv_dim, + alpha=network_alpha, + conv_alpha=conv_alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + use_tucker=use_tucker, + use_scalar=use_scalar, + network_module=algo, + train_norm=train_norm, + decompose_both=kwargs.get("decompose_both", False), + factor=kwargs.get("factor", -1), + block_size=block_size, + constraint=constraint, + rescaled=rescaled, + weight_decompose=weight_decompose, + wd_on_out=wd_on_output, + full_matrix=full_matrix, + bypass_mode=bypass_mode, + rs_lora=rs_lora, + unbalanced_factorization=unbalanced_factorization, + train_t5xxl=train_t5xxl, + ) + + return network + + +def create_network_from_weights( + multiplier, + file, + vae, + text_encoder, + unet, + weights_sd=None, + for_inference=False, + **kwargs, +): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping + unet_loras = {} + te_loras = {} + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if lora_name.startswith(LycorisNetworkKohya.LORA_PREFIX_UNET): + unet_loras[lora_name] = None + elif lora_name.startswith(LycorisNetworkKohya.LORA_PREFIX_TEXT_ENCODER): + te_loras[lora_name] = None + + for name, modules in unet.named_modules(): + lora_name = f"{LycorisNetworkKohya.LORA_PREFIX_UNET}_{name}".replace(".", "_") + if lora_name in unet_loras: + unet_loras[lora_name] = modules + + if isinstance(text_encoder, list): + text_encoders = text_encoder + use_index = True + else: + text_encoders = [text_encoder] + use_index = False + + for idx, te in enumerate(text_encoders): + if use_index: + prefix = f"{LycorisNetworkKohya.LORA_PREFIX_TEXT_ENCODER}{idx+1}" + else: + prefix = LycorisNetworkKohya.LORA_PREFIX_TEXT_ENCODER + for name, modules in te.named_modules(): + lora_name = f"{prefix}_{name}".replace(".", "_") + if lora_name in te_loras: + te_loras[lora_name] = modules + + original_level = logger.level + logger.setLevel(logging.ERROR) + network = LycorisNetworkKohya(text_encoder, unet) + network.unet_loras = [] + network.text_encoder_loras = [] + logger.setLevel(original_level) + + logger.info("Loading UNet Modules from state dict...") + for lora_name, orig_modules in unet_loras.items(): + if orig_modules is None: + continue + lyco_type, params = get_module(weights_sd, lora_name) + module = make_module(lyco_type, params, lora_name, orig_modules) + if module is not None: + network.unet_loras.append(module) + logger.info(f"{len(network.unet_loras)} Modules Loaded") + + logger.info("Loading TE Modules from state dict...") + for lora_name, orig_modules in te_loras.items(): + if orig_modules is None: + continue + lyco_type, params = get_module(weights_sd, lora_name) + module = make_module(lyco_type, params, lora_name, orig_modules) + if module is not None: + network.text_encoder_loras.append(module) + logger.info(f"{len(network.text_encoder_loras)} Modules Loaded") + + for lora in network.unet_loras + network.text_encoder_loras: + lora.multiplier = multiplier + + return network, weights_sd + + +class LycorisNetworkKohya(LycorisNetwork): + """ + LoRA + LoCon + """ + + # Ignore proj_in or proj_out, their channels is only a few. + ENABLE_CONV = True + UNET_TARGET_REPLACE_MODULE = [ + "Transformer2DModel", + "ResnetBlock2D", + "Downsample2D", + "Upsample2D", + "HunYuanDiTBlock", + "DoubleStreamBlock", + "SingleStreamBlock", + "SingleDiTBlock", + "MMDoubleStreamBlock", #HunYuanVideo + "MMSingleStreamBlock", #HunYuanVideo + ] + UNET_TARGET_REPLACE_NAME = [ + "conv_in", + "conv_out", + "time_embedding.linear_1", + "time_embedding.linear_2", + ] + TEXT_ENCODER_TARGET_REPLACE_MODULE = [ + "CLIPAttention", + "CLIPSdpaAttention", + "CLIPMLP", + "MT5Block", + "BertLayer", + ] + TEXT_ENCODER_TARGET_REPLACE_NAME = [] + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + MODULE_ALGO_MAP = {} + NAME_ALGO_MAP = {} + USE_FNMATCH = False + + @classmethod + def apply_preset(cls, preset): + if "enable_conv" in preset: + cls.ENABLE_CONV = preset["enable_conv"] + if "unet_target_module" in preset: + cls.UNET_TARGET_REPLACE_MODULE = preset["unet_target_module"] + if "unet_target_name" in preset: + cls.UNET_TARGET_REPLACE_NAME = preset["unet_target_name"] + if "text_encoder_target_module" in preset: + cls.TEXT_ENCODER_TARGET_REPLACE_MODULE = preset[ + "text_encoder_target_module" + ] + if "text_encoder_target_name" in preset: + cls.TEXT_ENCODER_TARGET_REPLACE_NAME = preset["text_encoder_target_name"] + if "module_algo_map" in preset: + cls.MODULE_ALGO_MAP = preset["module_algo_map"] + if "name_algo_map" in preset: + cls.NAME_ALGO_MAP = preset["name_algo_map"] + if "use_fnmatch" in preset: + cls.USE_FNMATCH = preset["use_fnmatch"] + return cls + + def __init__( + self, + text_encoder, + unet, + multiplier=1.0, + lora_dim=4, + conv_lora_dim=4, + alpha=1, + conv_alpha=1, + use_tucker=False, + dropout=0, + rank_dropout=0, + module_dropout=0, + network_module: str = "locon", + norm_modules=NormModule, + train_norm=False, + train_t5xxl=False, + **kwargs, + ) -> None: + torch.nn.Module.__init__(self) + root_kwargs = kwargs + self.multiplier = multiplier + self.lora_dim = lora_dim + self.train_t5xxl = train_t5xxl + + if not self.ENABLE_CONV: + conv_lora_dim = 0 + + self.conv_lora_dim = int(conv_lora_dim) + if self.conv_lora_dim and self.conv_lora_dim != self.lora_dim: + logger.info("Apply different lora dim for conv layer") + logger.info(f"Conv Dim: {conv_lora_dim}, Linear Dim: {lora_dim}") + elif self.conv_lora_dim == 0: + logger.info("Disable conv layer") + + self.alpha = alpha + self.conv_alpha = float(conv_alpha) + if self.conv_lora_dim and self.alpha != self.conv_alpha: + logger.info("Apply different alpha value for conv layer") + logger.info(f"Conv alpha: {conv_alpha}, Linear alpha: {alpha}") + + if 1 >= dropout >= 0: + logger.info(f"Use Dropout value: {dropout}") + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + self.use_tucker = use_tucker + + def create_single_module( + lora_name: str, + module: torch.nn.Module, + algo_name, + dim=None, + alpha=None, + use_tucker=self.use_tucker, + **kwargs, + ): + for k, v in root_kwargs.items(): + if k in kwargs: + continue + kwargs[k] = v + + if train_norm and "Norm" in module.__class__.__name__: + return norm_modules( + lora_name, + module, + self.multiplier, + self.rank_dropout, + self.module_dropout, + **kwargs, + ) + lora = None + if isinstance(module, torch.nn.Linear) and lora_dim > 0: + dim = dim or lora_dim + alpha = alpha or self.alpha + elif isinstance( + module, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d) + ): + k_size, *_ = module.kernel_size + if k_size == 1 and lora_dim > 0: + dim = dim or lora_dim + alpha = alpha or self.alpha + elif conv_lora_dim > 0 or dim: + dim = dim or conv_lora_dim + alpha = alpha or self.conv_alpha + else: + return None + else: + return None + lora = network_module_dict[algo_name]( + lora_name, + module, + self.multiplier, + dim, + alpha, + self.dropout, + self.rank_dropout, + self.module_dropout, + use_tucker, + **kwargs, + ) + return lora + + def create_modules_( + prefix: str, + root_module: torch.nn.Module, + algo, + configs={}, + ): + loras = {} + lora_names = [] + for name, module in root_module.named_modules(): + module_name = module.__class__.__name__ + if module_name in self.MODULE_ALGO_MAP and module is not root_module: + next_config = self.MODULE_ALGO_MAP[module_name] + next_algo = next_config.get("algo", algo) + new_loras, new_lora_names = create_modules_( + f"{prefix}_{name}", module, next_algo, next_config + ) + for lora_name, lora in zip(new_lora_names, new_loras): + if lora_name not in loras: + loras[lora_name] = lora + lora_names.append(lora_name) + continue + if name: + lora_name = prefix + "." + name + else: + lora_name = prefix + lora_name = lora_name.replace(".", "_") + if lora_name in loras: + continue + + lora = create_single_module(lora_name, module, algo, **configs) + if lora is not None: + loras[lora_name] = lora + lora_names.append(lora_name) + return [loras[lora_name] for lora_name in lora_names], lora_names + + # create module instances + def create_modules( + prefix, + root_module: torch.nn.Module, + target_replace_modules, + target_replace_names=[], + ) -> List: + logger.info("Create LyCORIS Module") + loras = [] + next_config = {} + for name, module in root_module.named_modules(): + module_name = module.__class__.__name__ + if module_name in target_replace_modules and not any( + self.match_fn(t, name) for t in target_replace_names + ): + if module_name in self.MODULE_ALGO_MAP: + next_config = self.MODULE_ALGO_MAP[module_name] + algo = next_config.get("algo", network_module) + else: + algo = network_module + loras.extend( + create_modules_(f"{prefix}_{name}", module, algo, next_config)[ + 0 + ] + ) + next_config = {} + elif name in target_replace_names or any( + self.match_fn(t, name) for t in target_replace_names + ): + conf_from_name = self.find_conf_for_name(name) + if conf_from_name is not None: + next_config = conf_from_name + algo = next_config.get("algo", network_module) + elif module_name in self.MODULE_ALGO_MAP: + next_config = self.MODULE_ALGO_MAP[module_name] + algo = next_config.get("algo", network_module) + else: + algo = network_module + lora_name = prefix + "." + name + lora_name = lora_name.replace(".", "_") + lora = create_single_module(lora_name, module, algo, **next_config) + next_config = {} + if lora is not None: + loras.append(lora) + return loras + + if network_module == GLoRAModule: + logger.info("GLoRA enabled, only train transformer") + # only train transformer (for GLoRA) + LycorisNetworkKohya.UNET_TARGET_REPLACE_MODULE = [ + "Transformer2DModel", + "Attention", + ] + LycorisNetworkKohya.UNET_TARGET_REPLACE_NAME = [] + + self.text_encoder_loras = [] + if text_encoder: + if isinstance(text_encoder, list): + text_encoders = text_encoder + use_index = True + else: + text_encoders = [text_encoder] + use_index = False + + for i, te in enumerate(text_encoders): + self.text_encoder_loras.extend( + create_modules( + LycorisNetworkKohya.LORA_PREFIX_TEXT_ENCODER + + (f"{i+1}" if use_index else ""), + te, + LycorisNetworkKohya.TEXT_ENCODER_TARGET_REPLACE_MODULE, + LycorisNetworkKohya.TEXT_ENCODER_TARGET_REPLACE_NAME, + ) + ) + logger.info( + f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules." + ) + + self.unet_loras = create_modules( + LycorisNetworkKohya.LORA_PREFIX_UNET, + unet, + LycorisNetworkKohya.UNET_TARGET_REPLACE_MODULE, + LycorisNetworkKohya.UNET_TARGET_REPLACE_NAME, + ) + logger.info(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.") + + algo_table = {} + for lora in self.text_encoder_loras + self.unet_loras: + algo_table[lora.__class__.__name__] = ( + algo_table.get(lora.__class__.__name__, 0) + 1 + ) + logger.info(f"module type table: {algo_table}") + + self.weights_sd = None + + self.loras = self.text_encoder_loras + self.unet_loras + # assertion + names = set() + for lora in self.loras: + assert ( + lora.lora_name not in names + ), f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def match_fn(self, pattern: str, name: str) -> bool: + if self.USE_FNMATCH: + return fnmatch.fnmatch(name, pattern) + return re.match(pattern, name) + + def find_conf_for_name( + self, + name: str, + ) -> dict[str, Any]: + if name in self.NAME_ALGO_MAP.keys(): + return self.NAME_ALGO_MAP[name] + + for key, value in self.NAME_ALGO_MAP.items(): + if self.match_fn(key, name): + return value + + return None + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + self.weights_sd = load_file(file) + else: + self.weights_sd = torch.load(file, map_location="cpu") + missing, unexpected = self.load_state_dict(self.weights_sd, strict=False) + state = {} + if missing: + state["missing keys"] = missing + if unexpected: + state["unexpected keys"] = unexpected + return state + + def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None): + assert ( + apply_text_encoder is not None and apply_unet is not None + ), f"internal error: flag not set" + + if apply_text_encoder: + logger.info("enable LyCORIS for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LyCORIS for U-Net") + else: + self.unet_loras = [] + + self.loras = self.text_encoder_loras + self.unet_loras + + for lora in self.loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + if self.weights_sd: + # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros) + info = self.load_state_dict(self.weights_sd, False) + logger.info(f"weights are loaded: {info}") + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(LycorisNetworkKohya.LORA_PREFIX_TEXT_ENCODER): + apply_text_encoder = True + elif key.startswith(LycorisNetworkKohya.LORA_PREFIX_UNET): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + self.loras = self.text_encoder_loras + self.unet_loras + super().merge_to(1) + + def apply_max_norm_regularization(self, max_norm_value, device): + key_scaled = 0 + norms = [] + for module in self.unet_loras + self.text_encoder_loras: + scaled, norm = module.apply_max_norm(max_norm_value, device) + if scaled is None: + continue + norms.append(norm) + key_scaled += scaled + + if key_scaled == 0: + return 0, 0, 0 + + return key_scaled, sum(norms) / len(norms), max(norms) + + def prepare_optimizer_params(self, text_encoder_lr=None, unet_lr: float = 1e-4, learning_rate=None): + def enumerate_params(loras): + params = [] + for lora in loras: + params.extend(lora.parameters()) + return params + + self.requires_grad_(True) + all_params = [] + lr_descriptions = [] + + if self.text_encoder_loras: + param_data = {"params": enumerate_params(self.text_encoder_loras)} + if text_encoder_lr is not None: + param_data["lr"] = text_encoder_lr + all_params.append(param_data) + lr_descriptions.append("text_encoder") + + if self.unet_loras: + param_data = {"params": enumerate_params(self.unet_loras)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) + lr_descriptions.append("unet") + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + #def on_step_start(self): + # pass + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash = precalculate_safetensors_hashes(state_dict) + metadata["sshs_model_hash"] = model_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) diff --git a/lycoris/logging.py b/lycoris/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..6c51ccca3ab40b9040f2900c9dc7cf3cde5c1d2a --- /dev/null +++ b/lycoris/logging.py @@ -0,0 +1,52 @@ +import sys +import copy +import logging +from functools import cache + + +class ColoredFormatter(logging.Formatter): + COLORS = { + "DEBUG": "\033[0;36m", # CYAN + "INFO": "\033[0;32m", # GREEN + "WARNING": "\033[0;33m", # YELLOW + "ERROR": "\033[0;31m", # RED + "CRITICAL": "\033[0;37;41m", # WHITE ON RED + "RESET": "\033[0m", # RESET COLOR + } + + def format(self, record): + colored_record = copy.copy(record) + levelname = colored_record.levelname + seq = self.COLORS.get(levelname, self.COLORS["RESET"]) + colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}" + return super().format(colored_record) + + +logger = logging.getLogger("LyCORIS") +logger.propagate = False +logger.setLevel(logging.INFO) + + +if not logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + ColoredFormatter( + "%(asctime)s|[%(name)s]-%(levelname)s: %(message)s", "%Y-%m-%d %H:%M:%S" + ) + ) + logger.addHandler(handler) + + +@cache +def info_once(msg): + logger.info(msg) + + +@cache +def warning_once(msg): + logger.warning(msg) + + +@cache +def error_once(msg): + logger.error(msg) diff --git a/lycoris/modules/__init__.py b/lycoris/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d0c8bde2ef50c0f9499e92198135e64d772c349 --- /dev/null +++ b/lycoris/modules/__init__.py @@ -0,0 +1,46 @@ +import torch +import torch.nn as nn + +from .base import LycorisBaseModule +from .locon import LoConModule +from .loha import LohaModule +from .lokr import LokrModule +from .full import FullModule +from .norms import NormModule +from .diag_oft import DiagOFTModule +from .boft import ButterflyOFTModule +from .glora import GLoRAModule +from .dylora import DyLoraModule +from .ia3 import IA3Module + +from ..functional.general import factorization + + +MODULE_LIST = [ + LoConModule, + LohaModule, + IA3Module, + LokrModule, + FullModule, + NormModule, + DiagOFTModule, + ButterflyOFTModule, + GLoRAModule, + DyLoraModule, +] + + +def get_module(lyco_state_dict, lora_name): + for module in MODULE_LIST: + if module.algo_check(lyco_state_dict, lora_name): + return module, tuple(module.extract_state_dict(lyco_state_dict, lora_name)) + return None, None + + +@torch.no_grad() +def make_module(lyco_type: LycorisBaseModule, params, lora_name, orig_module): + try: + module = lyco_type.make_module_from_state_dict(lora_name, orig_module, *params) + except NotImplementedError: + module = None + return module diff --git a/lycoris/modules/base.py b/lycoris/modules/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e8dec35cadde64848e9f3dad0929550268352561 --- /dev/null +++ b/lycoris/modules/base.py @@ -0,0 +1,315 @@ +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.utils.parametrize as parametrize + +from ..utils.quant import QuantLinears, log_bypass, log_suspect + + +class ModuleCustomSD(nn.Module): + def __init__(self): + super().__init__() + self._register_load_state_dict_pre_hook(self.load_weight_prehook) + self.register_load_state_dict_post_hook(self.load_weight_hook) + + def load_weight_prehook( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + pass + + def load_weight_hook(self, module, incompatible_keys): + pass + + def custom_state_dict(self): + return None + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + # TODO: Remove `args` and the parsing logic when BC allows. + if len(args) > 0: + if destination is None: + destination = args[0] + if len(args) > 1 and prefix == "": + prefix = args[1] + if len(args) > 2 and keep_vars is False: + keep_vars = args[2] + # DeprecationWarning is ignored by default + + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + + local_metadata = dict(version=self._version) + if hasattr(destination, "_metadata"): + destination._metadata[prefix[:-1]] = local_metadata + + if (custom_sd := self.custom_state_dict()) is not None: + for k, v in custom_sd.items(): + destination[f"{prefix}{k}"] = v + return destination + else: + return super().state_dict( + *args, destination=destination, prefix=prefix, keep_vars=keep_vars + ) + + +class LycorisBaseModule(ModuleCustomSD): + name: str + dtype_tensor: torch.Tensor + support_module = {} + weight_list = [] + weight_list_det = [] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + rank_dropout_scale=False, + bypass_mode=None, + **kwargs, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + self.not_supported = False + + self.module = type(org_module) + if isinstance(org_module, nn.Linear): + self.module_type = "linear" + self.shape = (org_module.out_features, org_module.in_features) + self.op = F.linear + self.dim = org_module.out_features + self.kw_dict = {} + elif isinstance(org_module, nn.Conv1d): + self.module_type = "conv1d" + self.shape = ( + org_module.out_channels, + org_module.in_channels, + *org_module.kernel_size, + ) + self.op = F.conv1d + self.dim = org_module.out_channels + self.kw_dict = { + "stride": org_module.stride, + "padding": org_module.padding, + "dilation": org_module.dilation, + "groups": org_module.groups, + } + elif isinstance(org_module, nn.Conv2d): + self.module_type = "conv2d" + self.shape = ( + org_module.out_channels, + org_module.in_channels, + *org_module.kernel_size, + ) + self.op = F.conv2d + self.dim = org_module.out_channels + self.kw_dict = { + "stride": org_module.stride, + "padding": org_module.padding, + "dilation": org_module.dilation, + "groups": org_module.groups, + } + elif isinstance(org_module, nn.Conv3d): + self.module_type = "conv3d" + self.shape = ( + org_module.out_channels, + org_module.in_channels, + *org_module.kernel_size, + ) + self.op = F.conv3d + self.dim = org_module.out_channels + self.kw_dict = { + "stride": org_module.stride, + "padding": org_module.padding, + "dilation": org_module.dilation, + "groups": org_module.groups, + } + elif isinstance(org_module, nn.LayerNorm): + self.module_type = "layernorm" + self.shape = tuple(org_module.normalized_shape) + self.op = F.layer_norm + self.dim = org_module.normalized_shape[0] + self.kw_dict = { + "normalized_shape": org_module.normalized_shape, + "eps": org_module.eps, + } + elif isinstance(org_module, nn.GroupNorm): + self.module_type = "groupnorm" + self.shape = (org_module.num_channels,) + self.op = F.group_norm + self.group_num = org_module.num_groups + self.dim = org_module.num_channels + self.kw_dict = {"num_groups": org_module.num_groups, "eps": org_module.eps} + else: + self.not_supported = True + self.module_type = "unknown" + + self.register_buffer("dtype_tensor", torch.tensor(0.0), persistent=False) + + self.is_quant = False + if isinstance(org_module, QuantLinears): + if not bypass_mode: + log_bypass() + self.is_quant = True + bypass_mode = True + if ( + isinstance(org_module, nn.Linear) + and org_module.__class__.__name__ != "Linear" + ): + if bypass_mode is None: + log_suspect() + bypass_mode = True + if bypass_mode == True: + self.is_quant = True + self.bypass_mode = bypass_mode + self.dropout = dropout + self.rank_dropout = rank_dropout + self.rank_dropout_scale = rank_dropout_scale + self.module_dropout = module_dropout + + ## Dropout things + # Since LoKr/LoHa/OFT/BOFT are hard to follow the rank_dropout definition from kohya + # We redefine the dropout procedure here. + # g(x) = WX + drop(Brank_drop(AX)) for LoCon(lora), bypass + # g(x) = WX + drop(ΔWX) for any algo except LoCon(lora), bypass + # g(x) = (W + Brank_drop(A))X for LoCon(lora), rebuid + # g(x) = (W + rank_drop(ΔW))X for any algo except LoCon(lora), rebuild + self.drop = nn.Identity() if dropout == 0 else nn.Dropout(dropout) + self.rank_drop = ( + nn.Identity() if rank_dropout == 0 else nn.Dropout(rank_dropout) + ) + + self.multiplier = multiplier + self.org_forward = org_module.forward + self.org_module = [org_module] + + @classmethod + def parametrize(cls, org_module, attr, *args, **kwargs): + from .full import FullModule + + if cls is FullModule: + raise RuntimeError("FullModule cannot be used for parametrize.") + target_param = getattr(org_module, attr) + kwargs["bypass_mode"] = False + if target_param.dim() == 2: + proxy_module = nn.Linear( + target_param.shape[0], target_param.shape[1], bias=False + ) + proxy_module.weight = target_param + elif target_param.dim() > 2: + module_type = [ + None, + None, + None, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + None, + None, + ][target_param.dim()] + proxy_module = module_type( + target_param.shape[0], + target_param.shape[1], + *target_param.shape[2:], + bias=False, + ) + proxy_module.weight = target_param + module_obj = cls("", proxy_module, *args, **kwargs) + module_obj.forward = module_obj.parametrize_forward + module_obj.to(target_param) + parametrize.register_parametrization(org_module, attr, module_obj) + return module_obj + + @classmethod + def algo_check(cls, state_dict, lora_name): + return any(f"{lora_name}.{k}" in state_dict for k in cls.weight_list_det) + + @classmethod + def extract_state_dict(cls, state_dict, lora_name): + return [state_dict.get(f"{lora_name}.{k}", None) for k in cls.weight_list] + + @classmethod + def make_module_from_state_dict(cls, lora_name, orig_module, *weights): + raise NotImplementedError + + @property + def dtype(self): + return self.dtype_tensor.dtype + + @property + def device(self): + return self.dtype_tensor.device + + @property + def org_weight(self): + return self.org_module[0].weight + + @org_weight.setter + def org_weight(self, value): + self.org_module[0].weight.data.copy_(value) + + def apply_to(self, **kwargs): + if self.not_supported: + return + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + + def restore(self): + if self.not_supported: + return + self.org_module[0].forward = self.org_forward + + def merge_to(self, multiplier=1.0): + if self.not_supported: + return + self_device = next(self.parameters()).device + self_dtype = next(self.parameters()).dtype + self.to(self.org_weight) + weight, bias = self.get_merged_weight( + multiplier, self.org_weight.shape, self.org_weight.device + ) + self.org_weight = weight.to(self.org_weight) + if bias is not None: + bias = bias.to(self.org_weight) + if self.org_module[0].bias is not None: + self.org_module[0].bias.data.copy_(bias) + else: + self.org_module[0].bias = nn.Parameter(bias) + self.to(self_device, self_dtype) + + def get_diff_weight(self, multiplier=1.0, shape=None, device=None): + raise NotImplementedError + + def get_merged_weight(self, multiplier=1.0, shape=None, device=None): + raise NotImplementedError + + @torch.no_grad() + def apply_max_norm(self, max_norm, device=None): + return None, None + + def bypass_forward_diff(self, x, scale=1): + raise NotImplementedError + + def bypass_forward(self, x, scale=1): + raise NotImplementedError + + def parametrize_forward(self, x: torch.Tensor, *args, **kwargs): + return self.get_merged_weight( + multiplier=self.multiplier, shape=x.shape, device=x.device + )[0].to(x.dtype) + + def forward(self, *args, **kwargs): + raise NotImplementedError diff --git a/lycoris/modules/boft.py b/lycoris/modules/boft.py new file mode 100644 index 0000000000000000000000000000000000000000..c547527a5c711934d3fb1cc427a15fd19caacb16 --- /dev/null +++ b/lycoris/modules/boft.py @@ -0,0 +1,255 @@ +from functools import cache +from math import log2 + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + +from .base import LycorisBaseModule +from ..functional import power2factorization +from ..logging import logger + + +@cache +def log_butterfly_factorize(dim, factor, result): + logger.info( + f"Use BOFT({int(log2(result[1]))}, {result[0]//2})" + f" (equivalent to factor={result[0]}) " + f"for {dim=} and {factor=}" + ) + + +def butterfly_factor(dimension: int, factor: int = -1) -> tuple[int, int]: + m, n = power2factorization(dimension, factor) + + if n == 0: + raise ValueError( + f"It is impossible to decompose {dimension} with factor {factor} under BOFT constraints." + ) + + log_butterfly_factorize(dimension, factor, (m, n)) + return m, n + + +class ButterflyOFTModule(LycorisBaseModule): + name = "boft" + support_module = { + "linear", + "conv1d", + "conv2d", + "conv3d", + } + weight_list = [ + "oft_blocks", + "rescale", + "alpha", + ] + weight_list_det = ["oft_blocks"] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + use_tucker=False, + use_scalar=False, + rank_dropout_scale=False, + constraint=0, + rescaled=False, + bypass_mode=None, + **kwargs, + ): + super().__init__( + lora_name, + org_module, + multiplier, + dropout, + rank_dropout, + module_dropout, + rank_dropout_scale, + bypass_mode, + ) + if self.module_type not in self.support_module: + raise ValueError(f"{self.module_type} is not supported in BOFT algo.") + + out_dim = self.dim + b, m_exp = butterfly_factor(out_dim, lora_dim) + self.block_size = b + self.block_num = m_exp + # BOFT(m, b) + self.boft_b = b + self.boft_m = sum(int(i) for i in f"{m_exp-1:b}") + 1 + # block_num > block_size + self.rescaled = rescaled + self.constraint = constraint * out_dim + self.register_buffer("alpha", torch.tensor(constraint)) + self.oft_blocks = nn.Parameter( + torch.zeros(self.boft_m, self.block_num, self.block_size, self.block_size) + ) + if rescaled: + self.rescale = nn.Parameter( + torch.ones(out_dim, *(1 for _ in range(org_module.weight.dim() - 1))) + ) + + @classmethod + def algo_check(cls, state_dict, lora_name): + if f"{lora_name}.oft_blocks" in state_dict: + oft_blocks = state_dict[f"{lora_name}.oft_blocks"] + if oft_blocks.ndim == 4: + return True + return False + + @classmethod + def make_module_from_state_dict( + cls, lora_name, orig_module, oft_blocks, rescale, alpha + ): + m, n, s, _ = oft_blocks.shape + module = cls( + lora_name, + orig_module, + 1, + lora_dim=s, + constraint=float(alpha), + rescaled=rescale is not None, + ) + module.oft_blocks.copy_(oft_blocks) + if rescale is not None: + module.rescale.copy_(rescale) + return module + + @property + def I(self): + return torch.eye(self.block_size, device=self.device) + + def get_r(self): + I = self.I + # for Q = -Q^T + q = self.oft_blocks - self.oft_blocks.transpose(-1, -2) + normed_q = q + # Diag OFT style constrain + if self.constraint > 0: + q_norm = torch.norm(q) + 1e-8 + if q_norm > self.constraint: + normed_q = q * self.constraint / q_norm + # use float() to prevent unsupported type + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r + + def make_weight(self, scale=1, device=None, diff=False): + m = self.boft_m + b = self.boft_b + r_b = b // 2 + r = self.get_r() + inp = org = self.org_weight.to(device, dtype=r.dtype) + + for i in range(m): + bi = r[i] # b_num, b_size, b_size + g = 2 + k = 2**i * r_b + if scale != 1: + bi = bi * scale + (1 - scale) * self.I + inp = ( + inp.unflatten(-1, (-1, g, k)) + .transpose(-2, -1) + .flatten(-3) + .unflatten(-1, (-1, b)) + ) + inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp) + inp = ( + inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3) + ) + + if self.rescaled: + inp = inp * self.rescale + + if diff: + inp = inp - org + + return inp.to(self.oft_blocks.dtype) + + def get_diff_weight(self, multiplier=1, shape=None, device=None): + diff = self.make_weight(scale=multiplier, device=device, diff=True) + if shape is not None: + diff = diff.view(shape) + return diff, None + + def get_merged_weight(self, multiplier=1, shape=None, device=None): + diff = self.make_weight(scale=multiplier, device=device) + if shape is not None: + diff = diff.view(shape) + return diff, None + + @torch.no_grad() + def apply_max_norm(self, max_norm, device=None): + orig_norm = self.oft_blocks.to(device).norm() + norm = torch.clamp(orig_norm, max_norm / 2) + desired = torch.clamp(norm, max=max_norm) + ratio = desired / norm + + scaled = norm != desired + if scaled: + self.oft_blocks *= ratio + + return scaled, orig_norm * ratio + + def _bypass_forward(self, x, scale=1, diff=False): + m = self.boft_m + b = self.boft_b + r_b = b // 2 + r = self.get_r() + inp = org = self.org_forward(x) + if self.op in {F.conv2d, F.conv1d, F.conv3d}: + inp = inp.transpose(1, -1) + + for i in range(m): + bi = r[i] # b_num, b_size, b_size + g = 2 + k = 2**i * r_b + if scale != 1: + bi = bi * scale + (1 - scale) * self.I + inp = ( + inp.unflatten(-1, (-1, g, k)) + .transpose(-2, -1) + .flatten(-3) + .unflatten(-1, (-1, b)) + ) + inp = torch.einsum("b i j, ... b j -> ... b i", bi, inp) + inp = ( + inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3) + ) + + if self.rescaled: + inp = inp * self.rescale.transpose(0, -1) + + if self.op in {F.conv2d, F.conv1d, F.conv3d}: + inp = inp.transpose(1, -1) + + if diff: + inp = inp - org + return inp + + def bypass_forward_diff(self, x, scale=1): + return self._bypass_forward(x, scale, diff=True) + + def bypass_forward(self, x, scale=1): + return self._bypass_forward(x, scale, diff=False) + + def forward(self, x, *args, **kwargs): + if self.module_dropout and self.training: + if torch.rand(1) < self.module_dropout: + return self.org_forward(x) + scale = self.multiplier + + if self.bypass_mode: + return self.bypass_forward(x, scale) + else: + w = self.make_weight(scale, x.device) + kw_dict = self.kw_dict | {"weight": w, "bias": self.org_module[0].bias} + return self.op(x, **kw_dict) diff --git a/lycoris/modules/diag_oft.py b/lycoris/modules/diag_oft.py new file mode 100644 index 0000000000000000000000000000000000000000..3805995ac7c8be57ce5c3036392c441d3cb670dc --- /dev/null +++ b/lycoris/modules/diag_oft.py @@ -0,0 +1,217 @@ +from functools import cache + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base import LycorisBaseModule +from ..functional import factorization +from ..logging import logger + + +@cache +def log_oft_factorize(dim, factor, num, bdim): + logger.info( + f"Use OFT(block num: {num}, block dim: {bdim})" + f" (equivalent to lora_dim={num}) " + f"for {dim=} and lora_dim={factor=}" + ) + + +class DiagOFTModule(LycorisBaseModule): + name = "diag-oft" + support_module = { + "linear", + "conv1d", + "conv2d", + "conv3d", + } + weight_list = [ + "oft_blocks", + "rescale", + "alpha", + ] + weight_list_det = ["oft_blocks"] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + use_tucker=False, + use_scalar=False, + rank_dropout_scale=False, + constraint=0, + rescaled=False, + bypass_mode=None, + **kwargs, + ): + super().__init__( + lora_name, + org_module, + multiplier, + dropout, + rank_dropout, + module_dropout, + rank_dropout_scale, + bypass_mode, + ) + if self.module_type not in self.support_module: + raise ValueError(f"{self.module_type} is not supported in Diag-OFT algo.") + + out_dim = self.dim + self.block_size, self.block_num = factorization(out_dim, lora_dim) + # block_num > block_size + self.rescaled = rescaled + self.constraint = constraint * out_dim + self.register_buffer("alpha", torch.tensor(constraint)) + self.oft_blocks = nn.Parameter( + torch.zeros(self.block_num, self.block_size, self.block_size) + ) + if rescaled: + self.rescale = nn.Parameter( + torch.ones(out_dim, *(1 for _ in range(org_module.weight.dim() - 1))) + ) + + log_oft_factorize( + dim=out_dim, + factor=lora_dim, + num=self.block_num, + bdim=self.block_size, + ) + + @classmethod + def algo_check(cls, state_dict, lora_name): + if f"{lora_name}.oft_blocks" in state_dict: + oft_blocks = state_dict[f"{lora_name}.oft_blocks"] + if oft_blocks.ndim == 3: + return True + return False + + @classmethod + def make_module_from_state_dict( + cls, lora_name, orig_module, oft_blocks, rescale, alpha + ): + n, s, _ = oft_blocks.shape + module = cls( + lora_name, + orig_module, + 1, + lora_dim=s, + constraint=float(alpha), + rescaled=rescale is not None, + ) + module.oft_blocks.copy_(oft_blocks) + if rescale is not None: + module.rescale.copy_(rescale) + return module + + @property + def I(self): + return torch.eye(self.block_size, device=self.device) + + def get_r(self): + I = self.I + # for Q = -Q^T + q = self.oft_blocks - self.oft_blocks.transpose(1, 2) + normed_q = q + if self.constraint > 0: + q_norm = torch.norm(q) + 1e-8 + if q_norm > self.constraint: + normed_q = q * self.constraint / q_norm + # use float() to prevent unsupported type + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r + + def make_weight(self, scale=1, device=None, diff=False): + r = self.get_r() + _, *shape = self.org_weight.shape + org_weight = self.org_weight.to(device, dtype=r.dtype) + org_weight = org_weight.view(self.block_num, self.block_size, *shape) + # Init R=0, so add I on it to ensure the output of step0 is original model output + weight = torch.einsum( + "k n m, k n ... -> k m ...", + self.rank_drop(r * scale) - scale * self.I + (0 if diff else self.I), + org_weight, + ).view(-1, *shape) + if self.rescaled: + weight = self.rescale * weight + if diff: + weight = weight + (self.rescale - 1) * org_weight + return weight.to(self.oft_blocks.dtype) + + def get_diff_weight(self, multiplier=1, shape=None, device=None): + diff = self.make_weight(scale=multiplier, device=device, diff=True) + if shape is not None: + diff = diff.view(shape) + return diff, None + + def get_merged_weight(self, multiplier=1, shape=None, device=None): + diff = self.make_weight(scale=multiplier, device=device) + if shape is not None: + diff = diff.view(shape) + return diff, None + + @torch.no_grad() + def apply_max_norm(self, max_norm, device=None): + orig_norm = self.oft_blocks.to(device).norm() + norm = torch.clamp(orig_norm, max_norm / 2) + desired = torch.clamp(norm, max=max_norm) + ratio = desired / norm + + scaled = norm != desired + if scaled: + self.oft_blocks *= ratio + + return scaled, orig_norm * ratio + + def _bypass_forward(self, x, scale=1, diff=False): + r = self.get_r() + org_out = self.org_forward(x) + if self.op in {F.conv2d, F.conv1d, F.conv3d}: + org_out = org_out.transpose(1, -1) + *shape, _ = org_out.shape + org_out = org_out.view(*shape, self.block_num, self.block_size) + mask = neg_mask = 1 + if self.dropout != 0 and self.training: + mask = torch.ones_like(org_out) + mask = self.drop(mask) + neg_mask = torch.max(mask) - mask + oft_out = torch.einsum( + "k n m, ... k n -> ... k m", + r * scale * mask + (1 - scale) * self.I * neg_mask, + org_out, + ) + if diff: + out = out - org_out + out = oft_out.view(*shape, -1) + if self.rescaled: + out = self.rescale.transpose(-1, 0) * out + out = out + (self.rescale.transpose(-1, 0) - 1) * org_out + if self.op in {F.conv2d, F.conv1d, F.conv3d}: + out = out.transpose(1, -1) + return out + + def bypass_forward_diff(self, x, scale=1): + return self._bypass_forward(x, scale, diff=True) + + def bypass_forward(self, x, scale=1): + return self._bypass_forward(x, scale, diff=False) + + def forward(self, x: torch.Tensor, *args, **kwargs): + if self.module_dropout and self.training: + if torch.rand(1) < self.module_dropout: + return self.org_forward(x) + scale = self.multiplier + + if self.bypass_mode: + return self.bypass_forward(x, scale) + else: + w = self.make_weight(scale, x.device) + kw_dict = self.kw_dict | {"weight": w, "bias": self.org_module[0].bias} + return self.op(x, **kw_dict) diff --git a/lycoris/modules/dylora.py b/lycoris/modules/dylora.py new file mode 100644 index 0000000000000000000000000000000000000000..4138142e819f6a7146baa73bf36efec22926b166 --- /dev/null +++ b/lycoris/modules/dylora.py @@ -0,0 +1,156 @@ +import math +import random + +import torch +import torch.nn as nn + +from .base import LycorisBaseModule +from ..utils import product + + +class DyLoraModule(LycorisBaseModule): + support_module = { + "linear", + "conv1d", + "conv2d", + "conv3d", + } + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + use_tucker=False, + block_size=4, + use_scalar=False, + rank_dropout_scale=False, + weight_decompose=False, + bypass_mode=None, + rs_lora=False, + train_on_input=False, + **kwargs, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__( + lora_name, + org_module, + multiplier, + dropout, + rank_dropout, + module_dropout, + rank_dropout_scale, + bypass_mode, + ) + if self.module_type not in self.support_module: + raise ValueError(f"{self.module_type} is not supported in IA^3 algo.") + assert lora_dim % block_size == 0, "lora_dim must be a multiple of block_size" + self.block_count = lora_dim // block_size + self.block_size = block_size + + shape = ( + self.shape[0], + product(self.shape[1:]), + ) + + self.lora_dim = lora_dim + self.up_list = nn.ParameterList( + [torch.empty(shape[0], self.block_size) for i in range(self.block_count)] + ) + self.down_list = nn.ParameterList( + [torch.empty(self.block_size, shape[1]) for i in range(self.block_count)] + ) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # Need more experiences on init method + for v in self.down_list: + torch.nn.init.kaiming_uniform_(v, a=math.sqrt(5)) + for v in self.up_list: + torch.nn.init.zeros_(v) + + def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False): + return + + def custom_state_dict(self): + destination = {} + destination["alpha"] = self.alpha + destination["lora_up.weight"] = nn.Parameter( + torch.concat(list(self.up_list), dim=1) + ) + destination["lora_down.weight"] = nn.Parameter( + torch.concat(list(self.down_list)).reshape( + self.lora_dim, -1, *self.shape[2:] + ) + ) + return destination + + def get_weight(self, rank): + b = math.ceil(rank / self.block_size) + down = torch.concat( + list(i.data for i in self.down_list[:b]) + list(self.down_list[b : (b + 1)]) + ) + up = torch.concat( + list(i.data for i in self.up_list[:b]) + list(self.up_list[b : (b + 1)]), + dim=1, + ) + return down, up, self.alpha / (b + 1) + + def get_random_rank_weight(self): + b = random.randint(0, self.block_count - 1) + return self.get_weight(b * self.block_size) + + def get_diff_weight(self, multiplier=1, shape=None, device=None, rank=None): + if rank is None: + down, up, scale = self.get_random_rank_weight() + else: + down, up, scale = self.get_weight(rank) + w = up @ (down * (scale * multiplier)) + if device is not None: + w = w.to(device) + if shape is not None: + w = w.view(shape) + else: + w = w.view(self.shape) + return w, None + + def get_merged_weight(self, multiplier=1, shape=None, device=None, rank=None): + diff, _ = self.get_diff_weight(multiplier, shape, device, rank) + return diff + self.org_weight, None + + def bypass_forward_diff(self, x, scale=1, rank=None): + if rank is None: + down, up, gamma = self.get_random_rank_weight() + else: + down, up, scale = self.get_weight(rank) + down = down.view(self.lora_dim, -1, *self.shape[2:]) + up = up.view(-1, self.lora_dim, *(1 for _ in self.shape[2:])) + scale = scale * gamma + return self.op(self.op(x, down, **self.kw_dict), up) + + def bypass_forward(self, x, scale=1, rank=None): + return self.org_forward(x) + self.bypass_forward_diff(x, scale, rank) + + def forward(self, x, *args, **kwargs): + if self.module_dropout and self.training: + if torch.rand(1) < self.module_dropout: + return self.org_forward(x) + if self.bypass_mode: + return self.bypass_forward(x, self.multiplier) + else: + weight = self.get_merged_weight(multiplier=self.multiplier)[0] + bias = ( + None + if self.org_module[0].bias is None + else self.org_module[0].bias.data + ) + return self.op(x, weight, bias, **self.kw_dict) diff --git a/lycoris/modules/full.py b/lycoris/modules/full.py new file mode 100644 index 0000000000000000000000000000000000000000..d4b0dc27fc47dec6759d086ba8b8e4726acfe875 --- /dev/null +++ b/lycoris/modules/full.py @@ -0,0 +1,214 @@ +from functools import cache + +import torch +import torch.nn as nn + +from .base import LycorisBaseModule +from ..logging import logger + + +@cache +def log_bypass_override(): + return logger.warning( + "Automatic Bypass-Mode detected in algo=full, " + "override with bypass_mode=False since algo=full not support bypass mode. " + "If you are using quantized model which require bypass mode, please don't use algo=full. " + ) + + +class FullModule(LycorisBaseModule): + name = "full" + support_module = { + "linear", + "conv1d", + "conv2d", + "conv3d", + } + weight_list = ["diff", "diff_b"] + weight_list_det = ["diff"] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + use_tucker=False, + use_scalar=False, + rank_dropout_scale=False, + bypass_mode=None, + **kwargs, + ): + org_bypass = bypass_mode + super().__init__( + lora_name, + org_module, + multiplier, + dropout, + rank_dropout, + module_dropout, + rank_dropout_scale, + bypass_mode, + ) + if bypass_mode and org_bypass is None: + self.bypass_mode = False + log_bypass_override() + + if self.module_type not in self.support_module: + raise ValueError(f"{self.module_type} is not supported in Full algo.") + + if self.is_quant: + raise ValueError( + "Quant Linear is not supported and meaningless in Full algo." + ) + + if self.bypass_mode: + raise ValueError("bypass mode is not supported in Full algo.") + + self.weight = nn.Parameter(torch.zeros_like(org_module.weight)) + if org_module.bias is not None: + self.bias = nn.Parameter(torch.zeros_like(org_module.bias)) + else: + self.bias = None + self.is_diff = True + self._org_weight = [self.org_module[0].weight.data.cpu().clone()] + if self.org_module[0].bias is not None: + self.org_bias = [self.org_module[0].bias.data.cpu().clone()] + else: + self.org_bias = None + + @classmethod + def make_module_from_state_dict(cls, lora_name, orig_module, diff, diff_b): + module = cls( + lora_name, + orig_module, + 1, + ) + module.weight.copy_(diff) + if diff_b is not None: + if orig_module.bias is not None: + module.bias.copy_(diff_b) + else: + module.bias = nn.Parameter(diff_b) + module.is_diff = True + return module + + @property + def org_weight(self): + return self._org_weight[0] + + @org_weight.setter + def org_weight(self, value): + self.org_module[0].weight.data.copy_(value) + + def apply_to(self, **kwargs): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + self.weight.data.add_(self.org_module[0].weight.data) + self._org_weight = [self.org_module[0].weight.data.cpu().clone()] + delattr(self.org_module[0], "weight") + if self.org_module[0].bias is not None: + self.bias.data.add_(self.org_module[0].bias.data) + self.org_bias = [self.org_module[0].bias.data.cpu().clone()] + delattr(self.org_module[0], "bias") + else: + self.org_bias = None + self.is_diff = False + + def restore(self): + self.org_module[0].forward = self.org_forward + self.org_module[0].weight = nn.Parameter(self._org_weight[0]) + if self.org_bias is not None: + self.org_module[0].bias = nn.Parameter(self.org_bias[0]) + + def custom_state_dict(self): + sd = {"diff": self.weight.data.cpu() - self._org_weight[0]} + if self.bias is not None: + sd["diff_b"] = self.bias.data.cpu() - self.org_bias[0] + return sd + + def load_weight_prehook( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + diff_weight = state_dict.pop(f"{prefix}diff") + state_dict[f"{prefix}weight"] = diff_weight + self.weight.data.to(diff_weight) + if f"{prefix}diff_b" in state_dict: + diff_bias = state_dict.pop(f"{prefix}diff_b") + state_dict[f"{prefix}bias"] = diff_bias + self.bias.data.to(diff_bias) + + def make_weight(self, scale=1, device=None): + drop = ( + torch.rand(self.dim, device=device) > self.rank_dropout + if self.rank_dropout and self.training + else 1 + ) + if drop != 1 or scale != 1 or self.is_diff: + diff_w, diff_b = self.get_diff_weight(scale, device=device) + weight = self.org_weight + diff_w * drop + if self.org_bias is not None: + bias = self.org_bias + diff_b * drop + else: + bias = None + else: + weight = self.weight + bias = self.bias + return weight, bias + + def get_diff_weight(self, multiplier=1, shape=None, device=None): + if self.is_diff: + diff_b = None + if self.bias is not None: + diff_b = self.bias * multiplier + return self.weight * multiplier, diff_b + org_weight = self.org_module[0].weight.to(device, dtype=self.weight.dtype) + diff = self.weight.to(device) - org_weight + diff_b = None + if shape: + diff = diff.view(shape) + if self.bias is not None: + org_bias = self.org_module[0].bias.to(device, dtype=self.bias.dtype) + diff_b = self.bias.to(device) - org_bias + if device is not None: + diff = diff.to(device) + if self.bias is not None: + diff_b = diff_b.to(device) + if multiplier != 1: + diff = diff * multiplier + if diff_b is not None: + diff_b = diff_b * multiplier + return diff * multiplier, diff_b + + def get_merged_weight(self, multiplier=1, shape=None, device=None): + weight, bias = self.make_weight(multiplier, device) + if shape is not None: + weight = weight.view(shape) + if bias is not None: + bias = bias.view(shape[0]) + return weight, bias + + def forward(self, x: torch.Tensor, *args, **kwargs): + if ( + self.module_dropout + and self.training + and torch.rand(1) < self.module_dropout + ): + original = True + else: + original = False + if original: + return self.org_forward(x) + scale = self.multiplier + weight, bias = self.make_weight(scale, x.device) + kw_dict = self.kw_dict | {"weight": weight, "bias": bias} + return self.op(x, **kw_dict) diff --git a/lycoris/modules/glora.py b/lycoris/modules/glora.py new file mode 100644 index 0000000000000000000000000000000000000000..45a47d8de6749dd333594d4e008c47e42d873eaf --- /dev/null +++ b/lycoris/modules/glora.py @@ -0,0 +1,262 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base import LycorisBaseModule +from ..functional import tucker_weight_from_conv + + +class GLoRAModule(LycorisBaseModule): + name = "glora" + support_module = { + "linear", + "conv1d", + "conv2d", + "conv3d", + } + weight_list = [ + "a1.weight", + "a2.weight", + "b1.weight", + "b2.weight", + "bm.weight", + "alpha", + ] + weight_list_det = ["a1.weight"] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + use_tucker=False, + use_scalar=False, + rank_dropout_scale=False, + weight_decompose=False, + bypass_mode=None, + rs_lora=False, + **kwargs, + ): + """ + f(x) = WX + WAX + BX, where A and B are low-rank matrices + bypass_forward(x) = W(X+A(X)) + B(X) + bypass_forward_diff(x) = W(A(X)) + B(X) + get_merged_weight() = W + WA + B + get_diff_weight() = WA + B + """ + super().__init__( + lora_name, + org_module, + multiplier, + dropout, + rank_dropout, + module_dropout, + rank_dropout_scale, + bypass_mode, + ) + if self.module_type not in self.support_module: + raise ValueError(f"{self.module_type} is not supported in GLoRA algo.") + self.lora_dim = lora_dim + self.tucker = False + self.rs_lora = rs_lora + + if self.module_type.startswith("conv"): + self.isconv = True + # For general LoCon + in_dim = org_module.in_channels + k_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + out_dim = org_module.out_channels + use_tucker = use_tucker and all(i == 1 for i in k_size) + self.down_op = self.op + self.up_op = self.op + + # A + self.a2 = self.module(in_dim, lora_dim, 1, bias=False) + self.a1 = self.module(lora_dim, in_dim, 1, bias=False) + + # B + if use_tucker and any(i != 1 for i in k_size): + self.b2 = self.module(in_dim, lora_dim, 1, bias=False) + self.bm = self.module( + lora_dim, lora_dim, k_size, stride, padding, bias=False + ) + self.tucker = True + else: + self.b2 = self.module( + in_dim, lora_dim, k_size, stride, padding, bias=False + ) + self.b1 = self.module(lora_dim, out_dim, 1, bias=False) + else: + self.isconv = False + self.down_op = F.linear + self.up_op = F.linear + in_dim = org_module.in_features + out_dim = org_module.out_features + self.a2 = nn.Linear(in_dim, lora_dim, bias=False) + self.a1 = nn.Linear(lora_dim, in_dim, bias=False) + self.b2 = nn.Linear(in_dim, lora_dim, bias=False) + self.b1 = nn.Linear(lora_dim, out_dim, bias=False) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = lora_dim if alpha is None or alpha == 0 else alpha + + r_factor = lora_dim + if self.rs_lora: + r_factor = math.sqrt(r_factor) + + self.scale = alpha / r_factor + + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + if use_scalar: + self.scalar = nn.Parameter(torch.tensor(0.0)) + else: + self.register_buffer("scalar", torch.tensor(1.0), persistent=False) + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.a1.weight, a=math.sqrt(5)) + torch.nn.init.kaiming_uniform_(self.b1.weight, a=math.sqrt(5)) + if use_scalar: + torch.nn.init.kaiming_uniform_(self.a2.weight, a=math.sqrt(5)) + torch.nn.init.kaiming_uniform_(self.b2.weight, a=math.sqrt(5)) + else: + torch.nn.init.zeros_(self.a2.weight) + torch.nn.init.zeros_(self.b2.weight) + + @classmethod + def make_module_from_state_dict( + cls, lora_name, orig_module, a1, a2, b1, b2, bm, alpha + ): + module = cls( + lora_name, + orig_module, + 1, + a2.size(0), + float(alpha), + use_tucker=bm is not None, + ) + module.a1.weight.data.copy_(a1) + module.a2.weight.data.copy_(a2) + module.b1.weight.data.copy_(b1) + module.b2.weight.data.copy_(b2) + if bm is not None: + module.bm.weight.data.copy_(bm) + return module + + def custom_state_dict(self): + destination = {} + destination["alpha"] = self.alpha + destination["a1.weight"] = self.a1.weight + destination["a2.weight"] = self.a2.weight * self.scalar + destination["b1.weight"] = self.b1.weight + destination["b2.weight"] = self.b2.weight * self.scalar + if self.tucker: + destination["bm.weight"] = self.bm.weight + return destination + + def load_weight_hook(self, module: nn.Module, incompatible_keys): + missing_keys = incompatible_keys.missing_keys + for key in missing_keys: + if "scalar" in key: + del missing_keys[missing_keys.index(key)] + if isinstance(self.scalar, nn.Parameter): + self.scalar.data.copy_(torch.ones_like(self.scalar)) + elif getattr(self, "scalar", None) is not None: + self.scalar.copy_(torch.ones_like(self.scalar)) + else: + self.register_buffer( + "scalar", torch.ones_like(self.scalar), persistent=False + ) + + def make_weight(self, device=None): + wa1 = self.a1.weight.view(self.a1.weight.size(0), -1) + wa2 = self.a2.weight.view(self.a2.weight.size(0), -1) + orig = self.org_weight + + if self.tucker: + wb = tucker_weight_from_conv(self.b1.weight, self.b2.weight, self.bm.weight) + else: + wb1 = self.b1.weight.view(self.b1.weight.size(0), -1) + wb2 = self.b2.weight.view(self.b2.weight.size(0), -1) + wb = wb1 @ wb2 + wb = wb.view(*orig.shape) + if orig.dim() > 2: + w_wa1 = torch.einsum("o i ..., i j -> o j ...", orig, wa1) + w_wa2 = torch.einsum("o i ..., i j -> o j ...", w_wa1, wa2) + else: + w_wa2 = (orig @ wa1) @ wa2 + return (wb + w_wa2) * self.scale * self.scalar + + def get_diff_weight(self, multiplier=1.0, shape=None, device=None): + weight = self.make_weight(device) * multiplier + if shape is not None: + weight = weight.view(shape) + return weight, None + + def get_merged_weight(self, multiplier=1, shape=None, device=None): + diff_w, _ = self.get_diff_weight(multiplier, shape, device) + return self.org_weight + diff_w, None + + def _bypass_forward(self, x, scale=1, diff=False): + scale = self.scale * scale + ax_mid = self.a2(x) * scale + bx_mid = self.b2(x) * scale + + if self.rank_dropout and self.training: + drop_a = ( + torch.rand(self.lora_dim, device=ax_mid.device) < self.rank_dropout + ).to(ax_mid.dtype) + drop_b = ( + torch.rand(self.lora_dim, device=bx_mid.device) < self.rank_dropout + ).to(bx_mid.dtype) + if self.rank_dropout_scale: + drop_a /= drop_a.mean() + drop_b /= drop_b.mean() + if (dims := len(x.shape)) == 4: + drop_a = drop_a.view(1, -1, 1, 1) + drop_b = drop_b.view(1, -1, 1, 1) + else: + drop_a = drop_a.view(*[1] * (dims - 1), -1) + drop_b = drop_b.view(*[1] * (dims - 1), -1) + ax_mid = ax_mid * drop_a + bx_mid = bx_mid * drop_b + return ( + self.org_forward( + (0 if diff else x) + self.drop(self.a1(ax_mid)) * self.scale + ) + + self.drop(self.b1(bx_mid)) * self.scale + ) + + def bypass_forward_diff(self, x, scale=1): + return self._bypass_forward(x, scale=scale, diff=True) + + def bypass_forward(self, x, scale=1): + return self._bypass_forward(x, scale=scale, diff=False) + + def forward(self, x, *args, **kwargs): + if self.module_dropout and self.training: + if torch.rand(1) < self.module_dropout: + return self.org_forward(x) + if self.bypass_mode: + return self.bypass_forward(x, self.multiplier) + else: + weight = ( + self.org_module[0].weight.data.to(self.dtype) + + self.get_diff_weight(multiplier=self.multiplier)[0] + ) + bias = ( + None + if self.org_module[0].bias is None + else self.org_module[0].bias.data + ) + return self.op(x, weight, bias, **self.kw_dict) diff --git a/lycoris/modules/ia3.py b/lycoris/modules/ia3.py new file mode 100644 index 0000000000000000000000000000000000000000..eeeaa1beae902fdfa870a429f03f805c51e13929 --- /dev/null +++ b/lycoris/modules/ia3.py @@ -0,0 +1,142 @@ +import torch +import torch.nn as nn + +from .base import LycorisBaseModule + + +class IA3Module(LycorisBaseModule): + name = "ia3" + support_module = { + "linear", + "conv1d", + "conv2d", + "conv3d", + } + weight_list = ["weight", "on_input"] + weight_list_det = ["on_input"] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + use_tucker=False, + use_scalar=False, + rank_dropout_scale=False, + weight_decompose=False, + bypass_mode=None, + rs_lora=False, + train_on_input=False, + **kwargs, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__( + lora_name, + org_module, + multiplier, + dropout, + rank_dropout, + module_dropout, + rank_dropout_scale, + bypass_mode, + ) + if self.module_type not in self.support_module: + raise ValueError(f"{self.module_type} is not supported in IA^3 algo.") + + if self.module_type.startswith("conv"): + self.isconv = True + in_dim = org_module.in_channels + out_dim = org_module.out_channels + if train_on_input: + train_dim = in_dim + else: + train_dim = out_dim + self.weight = nn.Parameter( + torch.empty(1, train_dim, *(1 for _ in self.shape[2:])) + ) + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + if train_on_input: + train_dim = in_dim + else: + train_dim = out_dim + + self.weight = nn.Parameter(torch.empty(train_dim)) + + # Need more experiences on init method + torch.nn.init.constant_(self.weight, 0) + self.train_input = train_on_input + self.register_buffer("on_input", torch.tensor(int(train_on_input))) + + @classmethod + def make_module_from_state_dict(cls, lora_name, orig_module, weight): + module = cls( + lora_name, + orig_module, + 1, + ) + module.weight.data.copy_(weight) + return module + + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + + def make_weight(self, multiplier=1, shape=None, device=None, diff=False): + weight = self.weight * multiplier + int(not diff) + if self.train_input: + diff = self.org_weight * weight + else: + diff = self.org_weight.transpose(0, 1) * weight + diff = diff.transpose(0, 1) + if shape is not None: + diff = diff.view(shape) + if device is not None: + diff = diff.to(device) + return diff + + def get_diff_weight(self, multiplier=1, shape=None, device=None): + diff = self.make_weight( + multiplier=multiplier, shape=shape, device=device, diff=True + ) + return diff, None + + def get_merged_weight(self, multiplier=1, shape=None, device=None): + diff = self.make_weight(multiplier=multiplier, shape=shape, device=device) + return diff, None + + def _bypass_forward(self, x, scale=1, diff=False): + weight = self.weight * scale + int(not diff) + if self.train_input: + x = x * weight + out = self.org_forward(x) + if not self.train_input: + out = out * weight + return out + + def bypass_forward_diff(self, x, scale=1): + return self._bypass_forward(x, scale, diff=True) + + def bypass_forward(self, x, scale=1): + return self._bypass_forward(x, scale, diff=False) + + def forward(self, x, *args, **kwargs): + if self.module_dropout and self.training: + if torch.rand(1) < self.module_dropout: + return self.org_forward(x) + if self.bypass_mode: + return self.bypass_forward(x, self.multiplier) + else: + weight = self.get_merged_weight(multiplier=self.multiplier)[0] + bias = ( + None + if self.org_module[0].bias is None + else self.org_module[0].bias.data + ) + return self.op(x, weight, bias, **self.kw_dict) diff --git a/lycoris/modules/locon.py b/lycoris/modules/locon.py new file mode 100644 index 0000000000000000000000000000000000000000..0338684a28ce7f0b648a471914fcad0744f62cab --- /dev/null +++ b/lycoris/modules/locon.py @@ -0,0 +1,332 @@ +import math +from functools import cache + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base import LycorisBaseModule +from ..functional.general import rebuild_tucker +from ..logging import logger + + +@cache +def log_wd(): + return logger.warning( + "Using weight_decompose=True with LoRA (DoRA) will ignore network_dropout." + "Only rank dropout and module dropout will be applied" + ) + + +class LoConModule(LycorisBaseModule): + name = "locon" + support_module = { + "linear", + "conv1d", + "conv2d", + "conv3d", + } + weight_list = [ + "lora_up.weight", + "lora_down.weight", + "lora_mid.weight", + "alpha", + "dora_scale", + ] + weight_list_det = ["lora_up.weight"] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + use_tucker=False, + use_scalar=False, + rank_dropout_scale=False, + weight_decompose=False, + wd_on_out=False, + bypass_mode=None, + rs_lora=False, + **kwargs, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__( + lora_name, + org_module, + multiplier, + dropout, + rank_dropout, + module_dropout, + rank_dropout_scale, + bypass_mode, + ) + if self.module_type not in self.support_module: + raise ValueError(f"{self.module_type} is not supported in LoRA/LoCon algo.") + self.lora_dim = lora_dim + self.tucker = False + self.rs_lora = rs_lora + + if self.module_type.startswith("conv"): + self.isconv = True + # For general LoCon + in_dim = org_module.in_channels + k_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + out_dim = org_module.out_channels + use_tucker = use_tucker and any(i != 1 for i in k_size) + self.down_op = self.op + self.up_op = self.op + if use_tucker and any(i != 1 for i in k_size): + self.lora_down = self.module(in_dim, lora_dim, 1, bias=False) + self.lora_mid = self.module( + lora_dim, lora_dim, k_size, stride, padding, bias=False + ) + self.tucker = True + else: + self.lora_down = self.module( + in_dim, lora_dim, k_size, stride, padding, bias=False + ) + self.lora_up = self.module(lora_dim, out_dim, 1, bias=False) + elif isinstance(org_module, nn.Linear): + self.isconv = False + self.down_op = F.linear + self.up_op = F.linear + in_dim = org_module.in_features + out_dim = org_module.out_features + self.lora_down = nn.Linear(in_dim, lora_dim, bias=False) + self.lora_up = nn.Linear(lora_dim, out_dim, bias=False) + else: + raise NotImplementedError + + self.wd = weight_decompose + self.wd_on_out = wd_on_out + if self.wd: + org_weight = org_module.weight.cpu().clone().float() + self.dora_norm_dims = org_weight.dim() - 1 + if self.wd_on_out: + self.dora_scale = nn.Parameter( + torch.norm( + org_weight.reshape(org_weight.shape[0], -1), + dim=1, + keepdim=True, + ).reshape(org_weight.shape[0], *[1] * self.dora_norm_dims) + ).float() + else: + self.dora_scale = nn.Parameter( + torch.norm( + org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1), + dim=1, + keepdim=True, + ) + .reshape(org_weight.shape[1], *[1] * self.dora_norm_dims) + .transpose(1, 0) + ).float() + + if dropout: + self.dropout = nn.Dropout(dropout) + if self.wd: + log_wd() + else: + self.dropout = nn.Identity() + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = lora_dim if alpha is None or alpha == 0 else alpha + + r_factor = lora_dim + if self.rs_lora: + r_factor = math.sqrt(r_factor) + + self.scale = alpha / r_factor + + self.register_buffer("alpha", torch.tensor(alpha * (lora_dim / r_factor))) + + if use_scalar: + self.scalar = nn.Parameter(torch.tensor(0.0)) + else: + self.register_buffer("scalar", torch.tensor(1.0), persistent=False) + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + if use_scalar: + torch.nn.init.kaiming_uniform_(self.lora_up.weight, a=math.sqrt(5)) + else: + torch.nn.init.constant_(self.lora_up.weight, 0) + if self.tucker: + torch.nn.init.kaiming_uniform_(self.lora_mid.weight, a=math.sqrt(5)) + + @classmethod + def make_module_from_state_dict( + cls, lora_name, orig_module, up, down, mid, alpha, dora_scale + ): + module = cls( + lora_name, + orig_module, + 1, + down.size(0), + float(alpha), + use_tucker=mid is not None, + weight_decompose=dora_scale is not None, + ) + module.lora_up.weight.data.copy_(up) + module.lora_down.weight.data.copy_(down) + if mid is not None: + module.lora_mid.weight.data.copy_(mid) + if dora_scale is not None: + module.dora_scale.copy_(dora_scale) + return module + + def load_weight_hook(self, module: nn.Module, incompatible_keys): + missing_keys = incompatible_keys.missing_keys + for key in missing_keys: + if "scalar" in key: + del missing_keys[missing_keys.index(key)] + if isinstance(self.scalar, nn.Parameter): + self.scalar.data.copy_(torch.ones_like(self.scalar)) + elif getattr(self, "scalar", None) is not None: + self.scalar.copy_(torch.ones_like(self.scalar)) + else: + self.register_buffer( + "scalar", torch.ones_like(self.scalar), persistent=False + ) + + def make_weight(self, device=None): + wa = self.lora_up.weight.to(device) + wb = self.lora_down.weight.to(device) + if self.tucker: + t = self.lora_mid.weight + wa = wa.view(wa.size(0), -1).transpose(0, 1) + wb = wb.view(wb.size(0), -1) + weight = rebuild_tucker(t, wa, wb) + else: + weight = wa.view(wa.size(0), -1) @ wb.view(wb.size(0), -1) + + weight = weight.view(self.shape) + if self.training and self.rank_dropout: + drop = (torch.rand(weight.size(0), device=device) > self.rank_dropout).to( + weight.dtype + ) + drop = drop.view(-1, *[1] * len(weight.shape[1:])) + if self.rank_dropout_scale: + drop /= drop.mean() + weight *= drop + + return weight * self.scalar.to(device) + + def get_diff_weight(self, multiplier=1, shape=None, device=None): + scale = self.scale * multiplier + diff = self.make_weight(device=device) * scale + if shape is not None: + diff = diff.view(shape) + if device is not None: + diff = diff.to(device) + return diff, None + + def get_merged_weight(self, multiplier=1, shape=None, device=None): + diff = self.get_diff_weight(multiplier=1, shape=shape, device=device)[0] + weight = self.org_weight + if self.wd: + merged = self.apply_weight_decompose(weight + diff, multiplier) + else: + merged = weight + diff * multiplier + return merged, None + + def apply_weight_decompose(self, weight, multiplier=1): + weight = weight.to(self.dora_scale.dtype) + if self.wd_on_out: + weight_norm = ( + weight.reshape(weight.shape[0], -1) + .norm(dim=1) + .reshape(weight.shape[0], *[1] * self.dora_norm_dims) + ) + torch.finfo(weight.dtype).eps + else: + weight_norm = ( + weight.transpose(0, 1) + .reshape(weight.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight.shape[1], *[1] * self.dora_norm_dims) + .transpose(0, 1) + ) + torch.finfo(weight.dtype).eps + + scale = self.dora_scale.to(weight.device) / weight_norm + if multiplier != 1: + scale = multiplier * (scale - 1) + 1 + + return weight * scale + + def custom_state_dict(self): + destination = {} + if self.wd: + destination["dora_scale"] = self.dora_scale + destination["alpha"] = self.alpha + destination["lora_up.weight"] = self.lora_up.weight * self.scalar + destination["lora_down.weight"] = self.lora_down.weight + if self.tucker: + destination["lora_mid.weight"] = self.lora_mid.weight + return destination + + @torch.no_grad() + def apply_max_norm(self, max_norm, device=None): + orig_norm = self.make_weight(device).norm() * self.scale + norm = torch.clamp(orig_norm, max_norm / 2) + desired = torch.clamp(norm, max=max_norm) + ratio = desired.cpu() / norm.cpu() + + scaled = norm != desired + if scaled: + self.scalar *= ratio + + return scaled, orig_norm * ratio + + def bypass_forward_diff(self, x, scale=1): + if self.tucker: + mid = self.lora_mid(self.lora_down(x)) + else: + mid = self.lora_down(x) + + if self.rank_dropout and self.training: + drop = ( + torch.rand(self.lora_dim, device=mid.device) > self.rank_dropout + ).to(mid.dtype) + if self.rank_dropout_scale: + drop /= drop.mean() + if (dims := len(x.shape)) == 4: + drop = drop.view(1, -1, 1, 1) + else: + drop = drop.view(*[1] * (dims - 1), -1) + mid = mid * drop + + return self.dropout(self.lora_up(mid) * self.scalar * self.scale * scale) + + def bypass_forward(self, x, scale=1): + return self.org_forward(x) + self.bypass_forward_diff(x, scale=scale) + + def forward(self, x): + if self.module_dropout and self.training: + if torch.rand(1) < self.module_dropout: + return self.org_forward(x) + scale = self.scale + + dtype = self.dtype + if not self.bypass_mode: + diff_weight = self.make_weight(x.device).to(dtype) * scale + weight = self.org_module[0].weight.data.to(dtype) + if self.wd: + weight = self.apply_weight_decompose( + weight + diff_weight, self.multiplier + ) + else: + weight = weight + diff_weight * self.multiplier + bias = ( + None + if self.org_module[0].bias is None + else self.org_module[0].bias.data + ) + return self.op(x, weight, bias, **self.kw_dict) + else: + return self.bypass_forward(x, scale=self.multiplier) diff --git a/lycoris/modules/loha.py b/lycoris/modules/loha.py new file mode 100644 index 0000000000000000000000000000000000000000..54f8e52eb0b3837ec01e412dd562806653f8a3a8 --- /dev/null +++ b/lycoris/modules/loha.py @@ -0,0 +1,329 @@ +import math + +import torch +import torch.nn as nn + +from .base import LycorisBaseModule +from ..functional.loha import diff_weight as loha_diff_weight + + +class LohaModule(LycorisBaseModule): + name = "loha" + support_module = { + "linear", + "conv1d", + "conv2d", + "conv3d", + } + weight_list = [ + "hada_w1_a", + "hada_w1_b", + "hada_w2_a", + "hada_w2_b", + "hada_t1", + "hada_t2", + "alpha", + "dora_scale", + ] + weight_list_det = ["hada_w1_a"] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + use_tucker=False, + use_scalar=False, + rank_dropout_scale=False, + weight_decompose=False, + wd_on_out=False, + bypass_mode=None, + rs_lora=False, + **kwargs, + ): + super().__init__( + lora_name, + org_module, + multiplier, + dropout, + rank_dropout, + module_dropout, + rank_dropout_scale, + bypass_mode, + ) + if self.module_type not in self.support_module: + raise ValueError(f"{self.module_type} is not supported in LoHa algo.") + self.lora_name = lora_name + self.lora_dim = lora_dim + self.tucker = False + self.rs_lora = rs_lora + + w_shape = self.shape + if self.module_type.startswith("conv"): + in_dim = org_module.in_channels + k_size = org_module.kernel_size + out_dim = org_module.out_channels + self.shape = (out_dim, in_dim, *k_size) + self.tucker = use_tucker and any(i != 1 for i in k_size) + if self.tucker: + w_shape = (out_dim, in_dim, *k_size) + else: + w_shape = (out_dim, in_dim * torch.tensor(k_size).prod().item()) + + if self.tucker: + self.hada_t1 = nn.Parameter(torch.empty(lora_dim, lora_dim, *w_shape[2:])) + self.hada_w1_a = nn.Parameter( + torch.empty(lora_dim, w_shape[0]) + ) # out_dim, 1-mode + self.hada_w1_b = nn.Parameter( + torch.empty(lora_dim, w_shape[1]) + ) # in_dim , 2-mode + + self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *w_shape[2:])) + self.hada_w2_a = nn.Parameter( + torch.empty(lora_dim, w_shape[0]) + ) # out_dim, 1-mode + self.hada_w2_b = nn.Parameter( + torch.empty(lora_dim, w_shape[1]) + ) # in_dim , 2-mode + else: + self.hada_w1_a = nn.Parameter(torch.empty(w_shape[0], lora_dim)) + self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, w_shape[1])) + + self.hada_w2_a = nn.Parameter(torch.empty(w_shape[0], lora_dim)) + self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, w_shape[1])) + + self.wd = weight_decompose + self.wd_on_out = wd_on_out + if self.wd: + org_weight = org_module.weight.cpu().clone().float() + self.dora_norm_dims = org_weight.dim() - 1 + if self.wd_on_out: + self.dora_scale = nn.Parameter( + torch.norm( + org_weight.reshape(org_weight.shape[0], -1), + dim=1, + keepdim=True, + ).reshape(org_weight.shape[0], *[1] * self.dora_norm_dims) + ).float() + else: + self.dora_scale = nn.Parameter( + torch.norm( + org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1), + dim=1, + keepdim=True, + ) + .reshape(org_weight.shape[1], *[1] * self.dora_norm_dims) + .transpose(1, 0) + ).float() + + if self.dropout: + print("[WARN]LoHa/LoKr haven't implemented normal dropout yet.") + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = lora_dim if alpha is None or alpha == 0 else alpha + + r_factor = lora_dim + if self.rs_lora: + r_factor = math.sqrt(r_factor) + + self.scale = alpha / r_factor + + self.register_buffer("alpha", torch.tensor(alpha * (lora_dim / r_factor))) + + if use_scalar: + self.scalar = nn.Parameter(torch.tensor(0.0)) + else: + self.register_buffer("scalar", torch.tensor(1.0), persistent=False) + # Need more experiments on init method + if self.tucker: + torch.nn.init.normal_(self.hada_t1, std=0.1) + torch.nn.init.normal_(self.hada_t2, std=0.1) + torch.nn.init.normal_(self.hada_w1_b, std=1) + torch.nn.init.normal_(self.hada_w1_a, std=0.1) + torch.nn.init.normal_(self.hada_w2_b, std=1) + if use_scalar: + torch.nn.init.normal_(self.hada_w2_a, std=0.1) + else: + torch.nn.init.constant_(self.hada_w2_a, 0) + + @classmethod + def make_module_from_state_dict( + cls, lora_name, orig_module, w1a, w1b, w2a, w2b, t1, t2, alpha, dora_scale + ): + module = cls( + lora_name, + orig_module, + 1, + w1b.size(0), + float(alpha), + use_tucker=t1 is not None, + weight_decompose=dora_scale is not None, + ) + module.hada_w1_a.copy_(w1a) + module.hada_w1_b.copy_(w1b) + module.hada_w2_a.copy_(w2a) + module.hada_w2_b.copy_(w2b) + if t1 is not None: + module.hada_t1.copy_(t1) + module.hada_t2.copy_(t2) + if dora_scale is not None: + module.dora_scale.copy_(dora_scale) + return module + + def load_weight_hook(self, module: nn.Module, incompatible_keys): + missing_keys = incompatible_keys.missing_keys + for key in missing_keys: + if "scalar" in key: + del missing_keys[missing_keys.index(key)] + if isinstance(self.scalar, nn.Parameter): + self.scalar.data.copy_(torch.ones_like(self.scalar)) + elif getattr(self, "scalar", None) is not None: + self.scalar.copy_(torch.ones_like(self.scalar)) + else: + self.register_buffer( + "scalar", torch.ones_like(self.scalar), persistent=False + ) + + def get_weight(self, shape): + scale = torch.tensor( + self.scale, dtype=self.hada_w1_b.dtype, device=self.hada_w1_b.device + ) + if self.tucker: + weight = loha_diff_weight( + self.hada_w1_b, + self.hada_w1_a, + self.hada_w2_b, + self.hada_w2_a, + self.hada_t1, + self.hada_t2, + gamma=scale, + ) + else: + weight = loha_diff_weight( + self.hada_w1_b, + self.hada_w1_a, + self.hada_w2_b, + self.hada_w2_a, + None, + None, + gamma=scale, + ) + if shape is not None: + weight = weight.reshape(shape) + if self.training and self.rank_dropout: + drop = (torch.rand(weight.size(0)) > self.rank_dropout).to(weight.dtype) + drop = drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device) + if self.rank_dropout_scale: + drop /= drop.mean() + weight *= drop + return weight + + def get_diff_weight(self, multiplier=1, shape=None, device=None): + scale = self.scale * multiplier + diff = self.get_weight(shape) * scale + if device is not None: + diff = diff.to(device) + return diff, None + + def get_merged_weight(self, multiplier=1, shape=None, device=None): + diff = self.get_diff_weight(multiplier=1, shape=shape, device=device)[0] + weight = self.org_weight + if self.wd: + merged = self.apply_weight_decompose(weight + diff, multiplier) + else: + merged = weight + diff * multiplier + return merged, None + + def apply_weight_decompose(self, weight, multiplier=1): + weight = weight.to(self.dora_scale.dtype) + if self.wd_on_out: + weight_norm = ( + weight.reshape(weight.shape[0], -1) + .norm(dim=1) + .reshape(weight.shape[0], *[1] * self.dora_norm_dims) + ) + torch.finfo(weight.dtype).eps + else: + weight_norm = ( + weight.transpose(0, 1) + .reshape(weight.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight.shape[1], *[1] * self.dora_norm_dims) + .transpose(0, 1) + ) + torch.finfo(weight.dtype).eps + + scale = self.dora_scale.to(weight.device) / weight_norm + if multiplier != 1: + scale = multiplier * (scale - 1) + 1 + + return weight * scale + + def custom_state_dict(self): + destination = {} + destination["alpha"] = self.alpha + if self.wd: + destination["dora_scale"] = self.dora_scale + destination["hada_w1_a"] = self.hada_w1_a * self.scalar + destination["hada_w1_b"] = self.hada_w1_b + destination["hada_w2_a"] = self.hada_w2_a + destination["hada_w2_b"] = self.hada_w2_b + if self.tucker: + destination["hada_t1"] = self.hada_t1 + destination["hada_t2"] = self.hada_t2 + return destination + + @torch.no_grad() + def apply_max_norm(self, max_norm, device=None): + orig_norm = (self.get_weight(self.shape) * self.scalar).norm() + norm = torch.clamp(orig_norm, max_norm / 2) + desired = torch.clamp(norm, max=max_norm) + ratio = desired.cpu() / norm.cpu() + + scaled = norm != desired + if scaled: + self.scalar *= ratio + + return scaled, orig_norm * ratio + + def bypass_forward_diff(self, x, scale=1): + diff_weight = self.get_weight(self.shape) * self.scalar * scale + return self.drop(self.op(x, diff_weight, **self.kw_dict)) + + def bypass_forward(self, x, scale=1): + return self.org_forward(x) + self.bypass_forward_diff(x, scale=scale) + + def forward(self, x: torch.Tensor, *args, **kwargs): + if self.module_dropout and self.training: + if torch.rand(1) < self.module_dropout: + return self.op( + x, + self.org_module[0].weight.data, + ( + None + if self.org_module[0].bias is None + else self.org_module[0].bias.data + ), + ) + if self.bypass_mode: + return self.bypass_forward(x, scale=self.multiplier) + else: + diff_weight = self.get_weight(self.shape).to(self.dtype) * self.scalar + weight = self.org_module[0].weight.data.to(self.dtype) + if self.wd: + weight = self.apply_weight_decompose( + weight + diff_weight, self.multiplier + ) + else: + weight = weight + diff_weight * self.multiplier + bias = ( + None + if self.org_module[0].bias is None + else self.org_module[0].bias.data + ) + return self.op(x, weight, bias, **self.kw_dict) diff --git a/lycoris/modules/lokr.py b/lycoris/modules/lokr.py new file mode 100644 index 0000000000000000000000000000000000000000..12ecf42e97f3b377bc04b047dc6b623c8a6992d9 --- /dev/null +++ b/lycoris/modules/lokr.py @@ -0,0 +1,609 @@ +import math +from functools import cache + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base import LycorisBaseModule +from ..functional import factorization, rebuild_tucker +from ..functional.lokr import make_kron +from ..logging import logger + + +@cache +def logging_force_full_matrix(lora_dim, dim, factor): + logger.warning( + f"lora_dim {lora_dim} is too large for" + f" dim={dim} and {factor=}" + ", using full matrix mode." + ) + + +class LokrModule(LycorisBaseModule): + name = "kron" + support_module = { + "linear", + "conv1d", + "conv2d", + "conv3d", + } + weight_list = [ + "lokr_w1", + "lokr_w1_a", + "lokr_w1_b", + "lokr_w2", + "lokr_w2_a", + "lokr_w2_b", + "lokr_t1", + "lokr_t2", + "alpha", + "dora_scale", + ] + weight_list_det = ["lokr_w1", "lokr_w1_a"] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=0.0, + rank_dropout=0.0, + module_dropout=0.0, + use_tucker=False, + use_scalar=False, + decompose_both=False, + factor: int = -1, # factorization factor + rank_dropout_scale=False, + weight_decompose=False, + wd_on_out=False, + full_matrix=False, + bypass_mode=None, + rs_lora=False, + unbalanced_factorization=False, + **kwargs, + ): + super().__init__( + lora_name, + org_module, + multiplier, + dropout, + rank_dropout, + module_dropout, + rank_dropout_scale, + bypass_mode, + ) + if self.module_type not in self.support_module: + raise ValueError(f"{self.module_type} is not supported in LoKr algo.") + + factor = int(factor) + self.lora_dim = lora_dim + self.tucker = False + self.use_w1 = False + self.use_w2 = False + self.full_matrix = full_matrix + self.rs_lora = rs_lora + + if self.module_type.startswith("conv"): + in_dim = org_module.in_channels + k_size = org_module.kernel_size + out_dim = org_module.out_channels + self.shape = (out_dim, in_dim, *k_size) + + in_m, in_n = factorization(in_dim, factor) + out_l, out_k = factorization(out_dim, factor) + if unbalanced_factorization: + out_l, out_k = out_k, out_l + shape = ((out_l, out_k), (in_m, in_n), *k_size) # ((a, b), (c, d), *k_size) + self.tucker = use_tucker and any(i != 1 for i in k_size) + if ( + decompose_both + and lora_dim < max(shape[0][0], shape[1][0]) / 2 + and not self.full_matrix + ): + self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim)) + self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0])) + else: + self.use_w1 = True + self.lokr_w1 = nn.Parameter( + torch.empty(shape[0][0], shape[1][0]) + ) # a*c, 1-mode + + if lora_dim >= max(shape[0][1], shape[1][1]) / 2 or self.full_matrix: + if not self.full_matrix: + logging_force_full_matrix(lora_dim, max(in_dim, out_dim), factor) + self.use_w2 = True + self.lokr_w2 = nn.Parameter( + torch.empty(shape[0][1], shape[1][1], *k_size) + ) + elif self.tucker: + self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *shape[2:])) + self.lokr_w2_a = nn.Parameter( + torch.empty(lora_dim, shape[0][1]) + ) # b, 1-mode + self.lokr_w2_b = nn.Parameter( + torch.empty(lora_dim, shape[1][1]) + ) # d, 2-mode + else: # Conv2d not tucker + # bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2] + self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim)) + self.lokr_w2_b = nn.Parameter( + torch.empty( + lora_dim, shape[1][1] * torch.tensor(shape[2:]).prod().item() + ) + ) + # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2) + else: # Linear + in_dim = org_module.in_features + out_dim = org_module.out_features + self.shape = (out_dim, in_dim) + + in_m, in_n = factorization(in_dim, factor) + out_l, out_k = factorization(out_dim, factor) + if unbalanced_factorization: + out_l, out_k = out_k, out_l + shape = ( + (out_l, out_k), + (in_m, in_n), + ) # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d + # smaller part. weight scale + if ( + decompose_both + and lora_dim < max(shape[0][0], shape[1][0]) / 2 + and not self.full_matrix + ): + self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim)) + self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0])) + else: + self.use_w1 = True + self.lokr_w1 = nn.Parameter( + torch.empty(shape[0][0], shape[1][0]) + ) # a*c, 1-mode + if lora_dim < max(shape[0][1], shape[1][1]) / 2 and not self.full_matrix: + # bigger part. weight and LoRA. [b, dim] x [dim, d] + self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim)) + self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1])) + # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd) + else: + if not self.full_matrix: + logging_force_full_matrix(lora_dim, max(in_dim, out_dim), factor) + self.use_w2 = True + self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1])) + + self.wd = weight_decompose + self.wd_on_out = wd_on_out + if self.wd: + org_weight = org_module.weight.cpu().clone().float() + self.dora_norm_dims = org_weight.dim() - 1 + if self.wd_on_out: + self.dora_scale = nn.Parameter( + torch.norm( + org_weight.reshape(org_weight.shape[0], -1), + dim=1, + keepdim=True, + ).reshape(org_weight.shape[0], *[1] * self.dora_norm_dims) + ).float() + else: + self.dora_scale = nn.Parameter( + torch.norm( + org_weight.transpose(1, 0).reshape(org_weight.shape[1], -1), + dim=1, + keepdim=True, + ) + .reshape(org_weight.shape[1], *[1] * self.dora_norm_dims) + .transpose(1, 0) + ).float() + + self.dropout = dropout + if dropout: + print("[WARN]LoHa/LoKr haven't implemented normal dropout yet.") + self.rank_dropout = rank_dropout + self.rank_dropout_scale = rank_dropout_scale + self.module_dropout = module_dropout + + if isinstance(alpha, torch.Tensor): + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = lora_dim if alpha is None or alpha == 0 else alpha + if self.use_w2 and self.use_w1: + # use scale = 1 + alpha = lora_dim + + r_factor = lora_dim + if self.rs_lora: + r_factor = math.sqrt(r_factor) + + self.scale = alpha / r_factor + + self.register_buffer("alpha", torch.tensor(alpha * (lora_dim / r_factor))) + + if use_scalar: + self.scalar = nn.Parameter(torch.tensor(0.0)) + else: + self.register_buffer("scalar", torch.tensor(1.0), persistent=False) + + if self.use_w2: + if use_scalar: + torch.nn.init.kaiming_uniform_(self.lokr_w2, a=math.sqrt(5)) + else: + torch.nn.init.constant_(self.lokr_w2, 0) + else: + if self.tucker: + torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5)) + torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5)) + if use_scalar: + torch.nn.init.kaiming_uniform_(self.lokr_w2_b, a=math.sqrt(5)) + else: + torch.nn.init.constant_(self.lokr_w2_b, 0) + + if self.use_w1: + torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5)) + else: + torch.nn.init.kaiming_uniform_(self.lokr_w1_a, a=math.sqrt(5)) + torch.nn.init.kaiming_uniform_(self.lokr_w1_b, a=math.sqrt(5)) + + @classmethod + def make_module_from_state_dict( + cls, + lora_name, + orig_module, + w1, + w1a, + w1b, + w2, + w2a, + w2b, + _, + t2, + alpha, + dora_scale, + ): + full_matrix = False + if w1a is not None: + lora_dim = w1a.size(1) + elif w2a is not None: + lora_dim = w2a.size(1) + else: + full_matrix = True + lora_dim = 1 + + if w1 is None: + out_dim = w1a.size(0) + in_dim = w1b.size(1) + else: + out_dim, in_dim = w1.shape + + shape_s = [out_dim, in_dim] + + if w2 is None: + out_dim *= w2a.size(0) + in_dim *= w2b.size(1) + else: + out_dim *= w2.size(0) + in_dim *= w2.size(1) + + if ( + shape_s[0] == factorization(out_dim, -1)[0] + and shape_s[1] == factorization(in_dim, -1)[0] + ): + factor = -1 + else: + w1_shape = w1.shape if w1 is not None else (w1a.size(0), w1b.size(1)) + w2_shape = w2.shape if w2 is not None else (w2a.size(0), w2b.size(1)) + shape_group_1 = (w1_shape[0], w2_shape[0]) + shape_group_2 = (w1_shape[1], w2_shape[1]) + w_shape = (w1_shape[0] * w2_shape[0], w1_shape[1] * w2_shape[1]) + factor1 = max(w1.shape) if w1 is not None else max(w1a.size(0), w1b.size(1)) + factor2 = max(w2.shape) if w2 is not None else max(w2a.size(0), w2b.size(1)) + if ( + w_shape[0] % factor1 == 0 + and w_shape[1] % factor1 == 0 + and factor1 in shape_group_1 + and factor1 in shape_group_2 + ): + factor = factor1 + elif ( + w_shape[0] % factor2 == 0 + and w_shape[1] % factor2 == 0 + and factor2 in shape_group_1 + and factor2 in shape_group_2 + ): + factor = factor2 + else: + factor = min(factor1, factor2) + + module = cls( + lora_name, + orig_module, + 1, + lora_dim, + float(alpha), + use_tucker=t2 is not None, + decompose_both=w1 is None and w2 is None, + factor=factor, + weight_decompose=dora_scale is not None, + full_matrix=full_matrix, + ) + if w1 is not None: + module.lokr_w1.copy_(w1) + else: + module.lokr_w1_a.copy_(w1a) + module.lokr_w1_b.copy_(w1b) + if w2 is not None: + module.lokr_w2.copy_(w2) + else: + module.lokr_w2_a.copy_(w2a) + module.lokr_w2_b.copy_(w2b) + if t2 is not None: + module.lokr_t2.copy_(t2) + if dora_scale is not None: + module.dora_scale.copy_(dora_scale) + return module + + def load_weight_hook(self, module: nn.Module, incompatible_keys): + missing_keys = incompatible_keys.missing_keys + for key in missing_keys: + if "scalar" in key: + del missing_keys[missing_keys.index(key)] + if isinstance(self.scalar, nn.Parameter): + self.scalar.data.copy_(torch.ones_like(self.scalar)) + elif getattr(self, "scalar", None) is not None: + self.scalar.copy_(torch.ones_like(self.scalar)) + else: + self.register_buffer( + "scalar", torch.ones_like(self.scalar), persistent=False + ) + + def get_weight(self, shape): + weight = make_kron( + self.lokr_w1 if self.use_w1 else self.lokr_w1_a @ self.lokr_w1_b, + ( + self.lokr_w2 + if self.use_w2 + else ( + rebuild_tucker(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) + if self.tucker + else self.lokr_w2_a @ self.lokr_w2_b + ) + ), + self.scale, + ) + dtype = weight.dtype + if shape is not None: + weight = weight.view(shape) + if self.training and self.rank_dropout: + drop = (torch.rand(weight.size(0)) > self.rank_dropout).to(dtype) + drop = drop.view(-1, *[1] * len(weight.shape[1:])) + if self.rank_dropout_scale: + drop /= drop.mean() + weight *= drop + return weight + + def get_diff_weight(self, multiplier=1, shape=None, device=None): + scale = self.scale * multiplier + diff = self.get_weight(shape) * scale + if device is not None: + diff = diff.to(device) + return diff, None + + def get_merged_weight(self, multiplier=1, shape=None, device=None): + diff = self.get_diff_weight(multiplier=1, shape=shape, device=device)[0] + weight = self.org_weight + if self.wd: + merged = self.apply_weight_decompose(weight + diff, multiplier) + else: + merged = weight + diff * multiplier + return merged, None + + def apply_weight_decompose(self, weight, multiplier=1): + weight = weight.to(self.dora_scale.dtype) + if self.wd_on_out: + weight_norm = ( + weight.reshape(weight.shape[0], -1) + .norm(dim=1) + .reshape(weight.shape[0], *[1] * self.dora_norm_dims) + ) + torch.finfo(weight.dtype).eps + else: + weight_norm = ( + weight.transpose(0, 1) + .reshape(weight.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight.shape[1], *[1] * self.dora_norm_dims) + .transpose(0, 1) + ) + torch.finfo(weight.dtype).eps + + scale = self.dora_scale.to(weight.device) / weight_norm + if multiplier != 1: + scale = multiplier * (scale - 1) + 1 + + return weight * scale + + def custom_state_dict(self): + destination = {} + destination["alpha"] = self.alpha + if self.wd: + destination["dora_scale"] = self.dora_scale + if self.use_w1: + destination["lokr_w1"] = self.lokr_w1 * self.scalar + else: + destination["lokr_w1_a"] = self.lokr_w1_a * self.scalar + destination["lokr_w1_b"] = self.lokr_w1_b + + if self.use_w2: + destination["lokr_w2"] = self.lokr_w2 + else: + destination["lokr_w2_a"] = self.lokr_w2_a + destination["lokr_w2_b"] = self.lokr_w2_b + if self.tucker: + destination["lokr_t2"] = self.lokr_t2 + return destination + + @torch.no_grad() + def apply_max_norm(self, max_norm, device=None): + orig_norm = self.get_weight(self.shape).norm() + norm = torch.clamp(orig_norm, max_norm / 2) + desired = torch.clamp(norm, max=max_norm) + ratio = desired.cpu() / norm.cpu() + + scaled = norm != desired + if scaled: + modules = 4 - self.use_w1 - self.use_w2 + (not self.use_w2 and self.tucker) + if self.use_w1: + self.lokr_w1 *= ratio ** (1 / modules) + else: + self.lokr_w1_a *= ratio ** (1 / modules) + self.lokr_w1_b *= ratio ** (1 / modules) + + if self.use_w2: + self.lokr_w2 *= ratio ** (1 / modules) + else: + if self.tucker: + self.lokr_t2 *= ratio ** (1 / modules) + self.lokr_w2_a *= ratio ** (1 / modules) + self.lokr_w2_b *= ratio ** (1 / modules) + + return scaled, orig_norm * ratio + + def bypass_forward_diff(self, h, scale=1): + is_conv = self.module_type.startswith("conv") + if self.use_w2: + ba = self.lokr_w2 + else: + a = self.lokr_w2_b + b = self.lokr_w2_a + + if self.tucker: + t = self.lokr_t2 + a = a.view(*a.shape, *[1] * (len(t.shape) - 2)) + b = b.view(*b.shape, *[1] * (len(t.shape) - 2)) + elif is_conv: + a = a.view(*a.shape, *self.shape[2:]) + b = b.view(*b.shape, *[1] * (len(self.shape) - 2)) + + if self.use_w1: + c = self.lokr_w1 + else: + c = self.lokr_w1_a @ self.lokr_w1_b + uq = c.size(1) + + if is_conv: + # (b, uq), vq, ... + b, _, *rest = h.shape + h_in_group = h.reshape(b * uq, -1, *rest) + else: + # b, ..., uq, vq + h_in_group = h.reshape(*h.shape[:-1], uq, -1) + + if self.use_w2: + hb = self.op(h_in_group, ba, **self.kw_dict) + else: + if is_conv: + if self.tucker: + ha = self.op(h_in_group, a) + ht = self.op(ha, t, **self.kw_dict) + hb = self.op(ht, b) + else: + ha = self.op(h_in_group, a, **self.kw_dict) + hb = self.op(ha, b) + else: + ha = self.op(h_in_group, a, **self.kw_dict) + hb = self.op(ha, b) + + if is_conv: + # (b, uq), vp, ..., f + # -> b, uq, vp, ..., f + # -> b, f, vp, ..., uq + hb = hb.view(b, -1, *hb.shape[1:]) + h_cross_group = hb.transpose(1, -1) + else: + # b, ..., uq, vq + # -> b, ..., vq, uq + h_cross_group = hb.transpose(-1, -2) + + hc = F.linear(h_cross_group, c) + if is_conv: + # b, f, vp, ..., up + # -> b, up, vp, ... ,f + # -> b, c, ..., f + hc = hc.transpose(1, -1) + h = hc.reshape(b, -1, *hc.shape[3:]) + else: + # b, ..., vp, up + # -> b, ..., up, vp + # -> b, ..., c + hc = hc.transpose(-1, -2) + h = hc.reshape(*hc.shape[:-2], -1) + + return self.drop(h * scale * self.scalar) + + def bypass_forward(self, x, scale=1): + return self.org_forward(x) + self.bypass_forward_diff(x, scale=scale) + + def forward(self, x: torch.Tensor, *args, **kwargs): + if self.module_dropout and self.training: + if torch.rand(1) < self.module_dropout: + return self.org_forward(x) + if self.bypass_mode: + return self.bypass_forward(x, self.multiplier) + else: + diff_weight = self.get_weight(self.shape).to(self.dtype) * self.scalar + weight = self.org_module[0].weight.data.to(self.dtype) + if self.wd: + weight = self.apply_weight_decompose( + weight + diff_weight, self.multiplier + ) + elif self.multiplier == 1: + weight = weight + diff_weight + else: + weight = weight + diff_weight * self.multiplier + bias = ( + None + if self.org_module[0].bias is None + else self.org_module[0].bias.data + ) + return self.op(x, weight, bias, **self.kw_dict) + + +if __name__ == "__main__": + base = nn.Conv2d(128, 128, 3, 1, 1) + net = LokrModule( + "", + base, + multiplier=1, + lora_dim=4, + alpha=1, + weight_decompose=False, + use_tucker=False, + use_scalar=False, + decompose_both=True, + ) + net.apply_to() + sd = net.state_dict() + for key in sd: + if key != "alpha": + sd[key] = torch.randn_like(sd[key]) + net.load_state_dict(sd) + + test_input = torch.randn(1, 128, 16, 16) + test_output = net(test_input) + print(test_output.shape) + + net2 = LokrModule( + "", + base, + multiplier=1, + lora_dim=4, + alpha=1, + weight_decompose=False, + use_tucker=False, + use_scalar=False, + bypass_mode=True, + decompose_both=True, + ) + net2.apply_to() + net2.load_state_dict(sd) + print(net2) + + test_output2 = net(test_input) + print(F.mse_loss(test_output, test_output2)) diff --git a/lycoris/modules/norms.py b/lycoris/modules/norms.py new file mode 100644 index 0000000000000000000000000000000000000000..6ab627a6573d7910a251997e258c358fd76f1a87 --- /dev/null +++ b/lycoris/modules/norms.py @@ -0,0 +1,161 @@ +import torch +import torch.nn as nn + +from .base import LycorisBaseModule +from ..logging import warning_once + + +class NormModule(LycorisBaseModule): + name = "norm" + support_module = { + "layernorm", + "groupnorm", + } + weight_list = ["w_norm", "b_norm"] + weight_list_det = ["w_norm"] + + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + rank_dropout=0.0, + module_dropout=0.0, + rank_dropout_scale=False, + **kwargs, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__( + lora_name=lora_name, + org_module=org_module, + multiplier=multiplier, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + rank_dropout_scale=rank_dropout_scale, + **kwargs, + ) + if self.module_type == "unknown": + if not hasattr(org_module, "weight") or not hasattr(org_module, "_norm"): + warning_once(f"{type(org_module)} is not supported in Norm algo.") + self.not_supported = True + return + else: + self.dim = org_module.weight.numel() + self.not_supported = False + elif self.module_type not in self.support_module: + warning_once(f"{self.module_type} is not supported in Norm algo.") + self.not_supported = True + return + + self.w_norm = nn.Parameter(torch.zeros(self.dim)) + if hasattr(org_module, "bias"): + self.b_norm = nn.Parameter(torch.zeros(self.dim)) + if hasattr(org_module, "_norm"): + self.org_norm = org_module._norm + else: + self.org_norm = None + + @classmethod + def make_module_from_state_dict(cls, lora_name, orig_module, w_norm, b_norm): + module = cls( + lora_name, + orig_module, + 1, + ) + module.w_norm.copy_(w_norm) + if b_norm is not None: + module.b_norm.copy_(b_norm) + return module + + def make_weight(self, scale=1, device=None): + org_weight = self.org_module[0].weight.to(device, dtype=self.w_norm.dtype) + if hasattr(self.org_module[0], "bias"): + org_bias = self.org_module[0].bias.to(device, dtype=self.b_norm.dtype) + else: + org_bias = None + if self.rank_dropout and self.training: + drop = (torch.rand(self.dim, device=device) < self.rank_dropout).to( + self.w_norm.device + ) + if self.rank_dropout_scale: + drop /= drop.mean() + else: + drop = 1 + drop = ( + torch.rand(self.dim, device=device) < self.rank_dropout + if self.rank_dropout and self.training + else 1 + ) + weight = self.w_norm.to(device) * drop * scale + if org_bias is not None: + bias = self.b_norm.to(device) * drop * scale + return org_weight + weight, org_bias + bias if org_bias is not None else None + + def get_diff_weight(self, multiplier=1, shape=None, device=None): + if self.not_supported: + return 0, 0 + w = self.w_norm * multiplier + if device is not None: + w = w.to(device) + if shape is not None: + w = w.view(shape) + if self.b_norm is not None: + b = self.b_norm * multiplier + if device is not None: + b = b.to(device) + if shape is not None: + b = b.view(shape) + else: + b = None + return w, b + + def get_merged_weight(self, multiplier=1, shape=None, device=None): + if self.not_supported: + return None, None + diff_w, diff_b = self.get_diff_weight(multiplier, shape, device) + org_w = self.org_module[0].weight.to(device, dtype=self.w_norm.dtype) + weight = org_w + diff_w + if diff_b is not None: + org_b = self.org_module[0].bias.to(device, dtype=self.b_norm.dtype) + bias = org_b + diff_b + else: + bias = None + return weight, bias + + def forward(self, x): + if self.not_supported or ( + self.module_dropout + and self.training + and torch.rand(1) < self.module_dropout + ): + return self.org_forward(x) + scale = self.multiplier + + w, b = self.make_weight(scale, x.device) + if self.org_norm is not None: + normed = self.org_norm(x) + scaled = normed * w + if b is not None: + scaled += b + return scaled + + kw_dict = self.kw_dict | {"weight": w, "bias": b} + return self.op(x, **kw_dict) + + +if __name__ == "__main__": + base = nn.LayerNorm(128).cuda() + norm = NormModule("test", base, 1).cuda() + print(norm) + test_input = torch.randn(1, 128).cuda() + test_output = norm(test_input) + torch.sum(test_output).backward() + print(test_output.shape) + + base = nn.GroupNorm(4, 128).cuda() + norm = NormModule("test", base, 1).cuda() + print(norm) + test_input = torch.randn(1, 128, 3, 3).cuda() + test_output = norm(test_input) + torch.sum(test_output).backward() + print(test_output.shape) diff --git a/lycoris/utils/__init__.py b/lycoris/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1acf0e129e5168ba917a83707c66507898f76e3 --- /dev/null +++ b/lycoris/utils/__init__.py @@ -0,0 +1,483 @@ +import re +import hashlib +from io import BytesIO +from typing import Dict, Tuple, Union + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.linalg as linalg + +import safetensors.torch + +from tqdm import tqdm +from .general import * + + +def load_bytes_in_safetensors(tensors): + bytes = safetensors.torch.save(tensors) + b = BytesIO(bytes) + + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + + offset = n + 8 + b.seek(offset) + + return b.read() + + +def precalculate_safetensors_hashes(state_dict): + # calculate each tensor one by one to reduce memory usage + hash_sha256 = hashlib.sha256() + for tensor in state_dict.values(): + single_tensor_sd = {"tensor": tensor} + bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd) + hash_sha256.update(bytes_for_tensor) + + return f"0x{hash_sha256.hexdigest()}" + + +def str_bool(val): + return str(val).lower() != "false" + + +def default(val, d): + return val if val is not None else d + + +def make_sparse(t: torch.Tensor, sparsity=0.95): + abs_t = torch.abs(t) + np_array = abs_t.detach().cpu().numpy() + quan = float(np.quantile(np_array, sparsity)) + sparse_t = t.masked_fill(abs_t < quan, 0) + return sparse_t + + +def extract_conv( + weight: Union[torch.Tensor, nn.Parameter], + mode="fixed", + mode_param=0, + device="cpu", + is_cp=False, +) -> Tuple[nn.Parameter, nn.Parameter]: + weight = weight.to(device) + out_ch, in_ch, kernel_size, _ = weight.shape + + U, S, Vh = linalg.svd(weight.reshape(out_ch, -1)) + + if mode == "full": + return weight, "full" + elif mode == "fixed": + lora_rank = mode_param + elif mode == "threshold": + assert mode_param >= 0 + lora_rank = torch.sum(S > mode_param) + elif mode == "ratio": + assert 1 >= mode_param >= 0 + min_s = torch.max(S) * mode_param + lora_rank = torch.sum(S > min_s) + elif mode == "quantile" or mode == "percentile": + assert 1 >= mode_param >= 0 + s_cum = torch.cumsum(S, dim=0) + min_cum_sum = mode_param * torch.sum(S) + lora_rank = torch.sum(s_cum < min_cum_sum) + else: + raise NotImplementedError( + 'Extract mode should be "fixed", "threshold", "ratio" or "quantile"' + ) + lora_rank = max(1, lora_rank) + lora_rank = min(out_ch, in_ch, lora_rank) + if lora_rank >= out_ch / 2 and not is_cp: + return weight, "full" + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S).to(device) + Vh = Vh[:lora_rank, :] + + diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach() + extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach() + extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach() + del U, S, Vh, weight + return (extract_weight_A, extract_weight_B, diff), "low rank" + + +def extract_linear( + weight: Union[torch.Tensor, nn.Parameter], + mode="fixed", + mode_param=0, + device="cpu", +) -> Tuple[nn.Parameter, nn.Parameter]: + weight = weight.to(device) + out_ch, in_ch = weight.shape + + U, S, Vh = linalg.svd(weight) + + if mode == "full": + return weight, "full" + elif mode == "fixed": + lora_rank = mode_param + elif mode == "threshold": + assert mode_param >= 0 + lora_rank = torch.sum(S > mode_param) + elif mode == "ratio": + assert 1 >= mode_param >= 0 + min_s = torch.max(S) * mode_param + lora_rank = torch.sum(S > min_s) + elif mode == "quantile" or mode == "percentile": + assert 1 >= mode_param >= 0 + s_cum = torch.cumsum(S, dim=0) + min_cum_sum = mode_param * torch.sum(S) + lora_rank = torch.sum(s_cum < min_cum_sum) + else: + raise NotImplementedError( + 'Extract mode should be "fixed", "threshold", "ratio" or "quantile"' + ) + lora_rank = max(1, lora_rank) + lora_rank = min(out_ch, in_ch, lora_rank) + if lora_rank >= out_ch / 2: + return weight, "full" + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S).to(device) + Vh = Vh[:lora_rank, :] + + diff = (weight - U @ Vh).detach() + extract_weight_A = Vh.reshape(lora_rank, in_ch).detach() + extract_weight_B = U.reshape(out_ch, lora_rank).detach() + del U, S, Vh, weight + return (extract_weight_A, extract_weight_B, diff), "low rank" + + +@torch.no_grad() +def extract_diff( + base_tes, + db_tes, + base_unet, + db_unet, + mode="fixed", + linear_mode_param=0, + conv_mode_param=0, + extract_device="cpu", + use_bias=False, + sparsity=0.98, + small_conv=True, +): + UNET_TARGET_REPLACE_MODULE = [ + "Linear", + "Conv2d", + "LayerNorm", + "GroupNorm", + "GroupNorm32", + ] + TEXT_ENCODER_TARGET_REPLACE_MODULE = [ + "Embedding", + "Linear", + "Conv2d", + "LayerNorm", + "GroupNorm", + "GroupNorm32", + ] + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + + def make_state_dict( + prefix, + root_module: torch.nn.Module, + target_module: torch.nn.Module, + target_replace_modules, + ): + loras = {} + temp = {} + + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + temp[name] = module + + for name, module in tqdm( + list((n, m) for n, m in target_module.named_modules() if n in temp) + ): + weights = temp[name] + lora_name = prefix + "." + name + lora_name = lora_name.replace(".", "_") + layer = module.__class__.__name__ + + if layer in { + "Linear", + "Conv2d", + "LayerNorm", + "GroupNorm", + "GroupNorm32", + "Embedding", + }: + root_weight = module.weight + if torch.allclose(root_weight, weights.weight): + continue + else: + continue + module = module.to(extract_device) + weights = weights.to(extract_device) + + if mode == "full": + decompose_mode = "full" + elif layer == "Linear": + weight, decompose_mode = extract_linear( + (root_weight - weights.weight), + mode, + linear_mode_param, + device=extract_device, + ) + if decompose_mode == "low rank": + extract_a, extract_b, diff = weight + elif layer == "Conv2d": + is_linear = root_weight.shape[2] == 1 and root_weight.shape[3] == 1 + weight, decompose_mode = extract_conv( + (root_weight - weights.weight), + mode, + linear_mode_param if is_linear else conv_mode_param, + device=extract_device, + ) + if decompose_mode == "low rank": + extract_a, extract_b, diff = weight + if small_conv and not is_linear and decompose_mode == "low rank": + dim = extract_a.size(0) + (extract_c, extract_a, _), _ = extract_conv( + extract_a.transpose(0, 1), + "fixed", + dim, + extract_device, + True, + ) + extract_a = extract_a.transpose(0, 1) + extract_c = extract_c.transpose(0, 1) + loras[f"{lora_name}.lora_mid.weight"] = ( + extract_c.detach().cpu().contiguous().half() + ) + diff = ( + ( + root_weight + - torch.einsum( + "i j k l, j r, p i -> p r k l", + extract_c, + extract_a.flatten(1, -1), + extract_b.flatten(1, -1), + ) + ) + .detach() + .cpu() + .contiguous() + ) + del extract_c + else: + module = module.to("cpu") + weights = weights.to("cpu") + continue + + if decompose_mode == "low rank": + loras[f"{lora_name}.lora_down.weight"] = ( + extract_a.detach().cpu().contiguous().half() + ) + loras[f"{lora_name}.lora_up.weight"] = ( + extract_b.detach().cpu().contiguous().half() + ) + loras[f"{lora_name}.alpha"] = torch.Tensor([extract_a.shape[0]]).half() + if use_bias: + diff = diff.detach().cpu().reshape(extract_b.size(0), -1) + sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce() + + indices = sparse_diff.indices().to(torch.int16) + values = sparse_diff.values().half() + loras[f"{lora_name}.bias_indices"] = indices + loras[f"{lora_name}.bias_values"] = values + loras[f"{lora_name}.bias_size"] = torch.tensor(diff.shape).to( + torch.int16 + ) + del extract_a, extract_b, diff + elif decompose_mode == "full": + if "Norm" in layer: + w_key = "w_norm" + b_key = "b_norm" + else: + w_key = "diff" + b_key = "diff_b" + weight_diff = module.weight - weights.weight + loras[f"{lora_name}.{w_key}"] = ( + weight_diff.detach().cpu().contiguous().half() + ) + if getattr(weights, "bias", None) is not None: + bias_diff = module.bias - weights.bias + loras[f"{lora_name}.{b_key}"] = ( + bias_diff.detach().cpu().contiguous().half() + ) + else: + raise NotImplementedError + module = module.to("cpu") + weights = weights.to("cpu") + return loras + + all_loras = {} + + all_loras |= make_state_dict( + LORA_PREFIX_UNET, + base_unet, + db_unet, + UNET_TARGET_REPLACE_MODULE, + ) + del base_unet, db_unet + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + for idx, (te1, te2) in enumerate(zip(base_tes, db_tes)): + if len(base_tes) > 1: + prefix = f"{LORA_PREFIX_TEXT_ENCODER}{idx+1}" + else: + prefix = LORA_PREFIX_TEXT_ENCODER + all_loras |= make_state_dict( + prefix, + te1, + te2, + TEXT_ENCODER_TARGET_REPLACE_MODULE, + ) + del te1, te2 + + all_lora_name = set() + for k in all_loras: + lora_name, weight = k.rsplit(".", 1) + all_lora_name.add(lora_name) + print(len(all_lora_name)) + return all_loras + + +re_digits = re.compile(r"\d+") +re_compiled = {} + +suffix_conversion = { + "attentions": {}, + "resnets": { + "conv1": "in_layers_2", + "conv2": "out_layers_3", + "norm1": "in_layers_0", + "norm2": "out_layers_0", + "time_emb_proj": "emb_layers_1", + "conv_shortcut": "skip_connection", + }, +} + + +def convert_diffusers_name_to_compvis(key): + def match(match_list, regex_text): + regex = re_compiled.get(regex_text) + if regex is None: + regex = re.compile(regex_text) + re_compiled[regex_text] = regex + + r = re.match(regex, key) + if not r: + return False + + match_list.clear() + match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) + return True + + m = [] + + if match(m, r"lora_unet_conv_in(.*)"): + return f"lora_unet_input_blocks_0_0{m[0]}" + + if match(m, r"lora_unet_conv_out(.*)"): + return f"lora_unet_out_2{m[0]}" + + if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"): + return f"lora_unet_time_embed_{m[0] * 2 - 2}{m[1]}" + + if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"lora_unet_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) + return ( + f"lora_unet_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" + ) + + if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"lora_unet_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): + return f"lora_unet_input_blocks_{3 + m[0] * 3}_0_op" + + if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): + return f"lora_unet_output_blocks_{2 + m[0] * 3}_2_conv" + return key + + +@torch.no_grad() +def merge(tes, unet, lyco_state_dict, scale: float = 1.0, device="cpu"): + from ..modules import make_module, get_module + + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + merged = 0 + + def merge_state_dict( + prefix, + root_module: torch.nn.Module, + lyco_state_dict: Dict[str, torch.Tensor], + ): + nonlocal merged + for child_name, child_module in tqdm( + list(root_module.named_modules()), desc=f"Merging {prefix}" + ): + lora_name = prefix + "." + child_name + lora_name = lora_name.replace(".", "_") + + lyco_type, params = get_module(lyco_state_dict, lora_name) + if lyco_type is None: + continue + module = make_module(lyco_type, params, lora_name, child_module) + if module is None: + continue + module.to(device) + module.merge_to(scale) + key_dict.pop(convert_diffusers_name_to_compvis(lora_name), None) + key_dict.pop(lora_name, None) + merged += 1 + + key_dict = {} + for k, v in tqdm(list(lyco_state_dict.items()), desc="Converting Dtype and Device"): + module, weight_key = k.split(".", 1) + convert_key = convert_diffusers_name_to_compvis(module) + if convert_key != module and len(tes) > 1: + # kohya's format for sdxl is as same as SGM, not diffusers + del lyco_state_dict[k] + key_dict[convert_key] = key_dict.get(convert_key, []) + [k] + k = f"{convert_key}.{weight_key}" + else: + key_dict[module] = key_dict.get(module, []) + [k] + lyco_state_dict[k] = v.float().cpu() + + for idx, te in enumerate(tes): + if len(tes) > 1: + prefix = LORA_PREFIX_TEXT_ENCODER + str(idx + 1) + else: + prefix = LORA_PREFIX_TEXT_ENCODER + merge_state_dict( + prefix, + te, + lyco_state_dict, + ) + torch.cuda.empty_cache() + merge_state_dict( + LORA_PREFIX_UNET, + unet, + lyco_state_dict, + ) + torch.cuda.empty_cache() + print(f"Unused state dict key: {key_dict}") + print(f"{merged} Modules been merged") diff --git a/lycoris/utils/general.py b/lycoris/utils/general.py new file mode 100644 index 0000000000000000000000000000000000000000..ccfaaa133f3c21b64637dedb4277a716f802a96e --- /dev/null +++ b/lycoris/utils/general.py @@ -0,0 +1,5 @@ +def product(xs: list[int | float]): + res = 1 + for x in xs: + res *= x + return res diff --git a/lycoris/utils/logger.py b/lycoris/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f51e575e982546c6d891b7511743514548b54e --- /dev/null +++ b/lycoris/utils/logger.py @@ -0,0 +1,35 @@ +import logging +import copy +import sys + + +class ColoredFormatter(logging.Formatter): + COLORS = { + "DEBUG": "\033[0;36m", # CYAN + "INFO": "\033[0;32m", # GREEN + "WARNING": "\033[0;33m", # YELLOW + "ERROR": "\033[0;31m", # RED + "CRITICAL": "\033[0;37;41m", # WHITE ON RED + "RESET": "\033[0m", # RESET COLOR + } + + def format(self, record): + colored_record = copy.copy(record) + levelname = colored_record.levelname + seq = self.COLORS.get(levelname, self.COLORS["RESET"]) + colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}" + return super().format(colored_record) + + +# Create a new logger +logger = logging.getLogger("LyCORIS") +logger.propagate = False + +# Add handler if we don't have one. +if not logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s")) + logger.addHandler(handler) + +logger.setLevel(logging.DEBUG) +logger.debug("Logger initialized.") diff --git a/lycoris/utils/preset.py b/lycoris/utils/preset.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ec844a1a7063e51da236255c8807d7624b0e7a --- /dev/null +++ b/lycoris/utils/preset.py @@ -0,0 +1,9 @@ +import toml + + +def read_preset(preset): + try: + return toml.load(preset) + except Exception as e: + print("Error: cannot read preset file. ", e) + return None diff --git a/lycoris/utils/quant.py b/lycoris/utils/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..aa4660e0a5d9a2b460f16c1beb6949e3734ae88d --- /dev/null +++ b/lycoris/utils/quant.py @@ -0,0 +1,88 @@ +from functools import cache + +SUPPORT_QUANT = False +try: + from bitsandbytes.nn import LinearNF4, Linear8bitLt, LinearFP4 + + SUPPORT_QUANT = True +except Exception: + import torch.nn as nn + + class LinearNF4(nn.Linear): + pass + + class Linear8bitLt(nn.Linear): + pass + + class LinearFP4(nn.Linear): + pass + + +try: + from quanto.nn import QLinear, QConv2d, QLayerNorm + + SUPPORT_QUANT = True +except Exception: + import torch.nn as nn + + class QLinear(nn.Linear): + pass + + class QConv2d(nn.Conv2d): + pass + + class QLayerNorm(nn.LayerNorm): + pass + + +try: + from optimum.quanto.nn import ( + QLinear as QLinearOpt, + QConv2d as QConv2dOpt, + QLayerNorm as QLayerNormOpt, + ) + + SUPPORT_QUANT = True +except Exception: + import torch.nn as nn + + class QLinearOpt(nn.Linear): + pass + + class QConv2dOpt(nn.Conv2d): + pass + + class QLayerNormOpt(nn.LayerNorm): + pass + + +from ..logging import logger + + +QuantLinears = ( + Linear8bitLt, + LinearFP4, + LinearNF4, + QLinear, + QConv2d, + QLayerNorm, + QLinearOpt, + QConv2dOpt, + QLayerNormOpt, +) + + +@cache +def log_bypass(): + return logger.warning( + "Using bnb/quanto/optimum-quanto with LyCORIS will enable force-bypass mode." + ) + + +@cache +def log_suspect(): + return logger.warning( + "Non-native Linear detected but bypass_mode is not set. " + "Automatically using force-bypass mode to avoid possible issues. " + "Please set bypass_mode=False explicitly if there are no quantized layers." + ) diff --git a/lycoris/utils/xformers_utils.py b/lycoris/utils/xformers_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..31f737ceafa7f6743100d3b64677cac08c1811fb --- /dev/null +++ b/lycoris/utils/xformers_utils.py @@ -0,0 +1,13 @@ +memory_efficient_attention = None +try: + import xformers +except Exception: + pass + +try: + from xformers.ops import memory_efficient_attention + + XFORMERS_AVAIL = True +except Exception: + memory_efficient_attention = None + XFORMERS_AVAIL = False diff --git a/lycoris/wrapper.py b/lycoris/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..3d53a11f3c045e89a1967d401ed352829ddf6655 --- /dev/null +++ b/lycoris/wrapper.py @@ -0,0 +1,640 @@ +# General LyCORIS wrapper based on kohya-ss/sd-scripts' style +import os +import fnmatch +import re +import logging + +from typing import Any, List + +import torch +import torch.nn as nn + +from .modules.locon import LoConModule +from .modules.loha import LohaModule +from .modules.lokr import LokrModule +from .modules.dylora import DyLoraModule +from .modules.glora import GLoRAModule +from .modules.norms import NormModule +from .modules.full import FullModule +from .modules.diag_oft import DiagOFTModule +from .modules.boft import ButterflyOFTModule +from .modules import get_module, make_module + +from .config import PRESET +from .utils.preset import read_preset +from .utils import str_bool +from .logging import logger + + +VALID_PRESET_KEYS = [ + "enable_conv", + "target_module", + "target_name", + "module_algo_map", + "name_algo_map", + "lora_prefix", + "use_fnmatch", + "unet_target_module", + "unet_target_name", + "text_encoder_target_module", + "text_encoder_target_name", + "exclude_name", +] + + +network_module_dict = { + "lora": LoConModule, + "locon": LoConModule, + "loha": LohaModule, + "lokr": LokrModule, + "dylora": DyLoraModule, + "glora": GLoRAModule, + "full": FullModule, + "diag-oft": DiagOFTModule, + "boft": ButterflyOFTModule, +} +deprecated_arg_dict = { + "disable_conv_cp": "use_tucker", + "use_cp": "use_tucker", + "use_conv_cp": "use_tucker", + "constrain": "constraint", +} + + +def create_lycoris(module, multiplier=1.0, linear_dim=4, linear_alpha=1, **kwargs): + for key, value in list(kwargs.items()): + if key in deprecated_arg_dict: + logger.warning( + f"{key} is deprecated. Please use {deprecated_arg_dict[key]} instead.", + stacklevel=2, + ) + kwargs[deprecated_arg_dict[key]] = value + if linear_dim is None: + linear_dim = 4 # default + conv_dim = int(kwargs.get("conv_dim", linear_dim) or linear_dim) + conv_alpha = float(kwargs.get("conv_alpha", linear_alpha) or linear_alpha) + dropout = float(kwargs.get("dropout", 0.0) or 0.0) + rank_dropout = float(kwargs.get("rank_dropout", 0.0) or 0.0) + module_dropout = float(kwargs.get("module_dropout", 0.0) or 0.0) + algo = (kwargs.get("algo", "lora") or "lora").lower() + use_tucker = str_bool( + not kwargs.get("disable_conv_cp", True) + or kwargs.get("use_conv_cp", False) + or kwargs.get("use_cp", False) + or kwargs.get("use_tucker", False) + ) + use_scalar = str_bool(kwargs.get("use_scalar", False)) + block_size = int(kwargs.get("block_size", 4) or 4) + train_norm = str_bool(kwargs.get("train_norm", False)) + constraint = float(kwargs.get("constraint", 0) or 0) + rescaled = str_bool(kwargs.get("rescaled", False)) + weight_decompose = str_bool(kwargs.get("dora_wd", False)) + wd_on_output = str_bool(kwargs.get("wd_on_output", False)) + full_matrix = str_bool(kwargs.get("full_matrix", False)) + bypass_mode = str_bool(kwargs.get("bypass_mode", None)) + unbalanced_factorization = str_bool(kwargs.get("unbalanced_factorization", False)) + + if unbalanced_factorization: + logger.info("Unbalanced factorization for LoKr is enabled") + + if bypass_mode: + logger.info("Bypass mode is enabled") + + if weight_decompose: + logger.info("Weight decomposition is enabled") + + if full_matrix: + logger.info("Full matrix mode for LoKr is enabled") + + preset = kwargs.get("preset", "full") + if preset not in PRESET: + preset = read_preset(preset) + else: + preset = PRESET[preset] + assert preset is not None + LycorisNetwork.apply_preset(preset) + + logger.info(f"Using rank adaptation algo: {algo}") + + network = LycorisNetwork( + module, + multiplier=multiplier, + lora_dim=linear_dim, + conv_lora_dim=conv_dim, + alpha=linear_alpha, + conv_alpha=conv_alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + use_tucker=use_tucker, + use_scalar=use_scalar, + network_module=algo, + train_norm=train_norm, + decompose_both=kwargs.get("decompose_both", False), + factor=kwargs.get("factor", -1), + block_size=block_size, + constraint=constraint, + rescaled=rescaled, + weight_decompose=weight_decompose, + wd_on_out=wd_on_output, + full_matrix=full_matrix, + bypass_mode=bypass_mode, + unbalanced_factorization=unbalanced_factorization, + ) + + return network + + +def create_lycoris_from_weights(multiplier, file, module, weights_sd=None, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping + loras = {} + for key in weights_sd: + if "." not in key: + continue + + lora_name = key.split(".")[0] + loras[lora_name] = None + + for name, modules in module.named_modules(): + lora_name = f"{LycorisNetwork.LORA_PREFIX}_{name}".replace(".", "_") + if lora_name in loras: + loras[lora_name] = modules + + original_level = logger.level + logger.setLevel(logging.ERROR) + network = LycorisNetwork(module, init_only=True) + network.multiplier = multiplier + network.loras = [] + logger.setLevel(original_level) + + logger.info("Loading Modules from state dict...") + for lora_name, orig_modules in loras.items(): + if orig_modules is None: + continue + lyco_type, params = get_module(weights_sd, lora_name) + module = make_module(lyco_type, params, lora_name, orig_modules) + if module is not None: + network.loras.append(module) + network.algo_table[module.__class__.__name__] = ( + network.algo_table.get(module.__class__.__name__, 0) + 1 + ) + logger.info(f"{len(network.loras)} Modules Loaded") + + for lora in network.loras: + lora.multiplier = multiplier + + return network, weights_sd + + +class LycorisNetwork(torch.nn.Module): + ENABLE_CONV = True + TARGET_REPLACE_MODULE = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "GroupNorm", + "LayerNorm", + ] + TARGET_REPLACE_NAME = [] + LORA_PREFIX = "lycoris" + MODULE_ALGO_MAP = {} + NAME_ALGO_MAP = {} + USE_FNMATCH = False + TARGET_EXCLUDE_NAME = [] + + @classmethod + def apply_preset(cls, preset): + for preset_key in preset.keys(): + if preset_key not in VALID_PRESET_KEYS: + raise KeyError( + f'Unknown preset key "{preset_key}". Valid keys: {VALID_PRESET_KEYS}' + ) + + if "enable_conv" in preset: + cls.ENABLE_CONV = preset["enable_conv"] + if "target_module" in preset: + cls.TARGET_REPLACE_MODULE = preset["target_module"] + if "target_name" in preset: + cls.TARGET_REPLACE_NAME = preset["target_name"] + if "module_algo_map" in preset: + cls.MODULE_ALGO_MAP = preset["module_algo_map"] + if "name_algo_map" in preset: + cls.NAME_ALGO_MAP = preset["name_algo_map"] + if "lora_prefix" in preset: + cls.LORA_PREFIX = preset["lora_prefix"] + if "use_fnmatch" in preset: + cls.USE_FNMATCH = preset["use_fnmatch"] + if "exclude_name" in preset: + cls.TARGET_EXCLUDE_NAME = preset["exclude_name"] + return cls + + def __init__( + self, + module: nn.Module, + multiplier=1.0, + lora_dim=4, + conv_lora_dim=4, + alpha=1, + conv_alpha=1, + use_tucker=False, + dropout=0, + rank_dropout=0, + module_dropout=0, + network_module: str = "locon", + norm_modules=NormModule, + train_norm=False, + init_only=False, + **kwargs, + ) -> None: + super().__init__() + root_kwargs = kwargs + self.weights_sd = None + if init_only: + self.multiplier = 1 + self.lora_dim = 0 + self.alpha = 1 + self.conv_lora_dim = 0 + self.conv_alpha = 1 + self.dropout = 0 + self.rank_dropout = 0 + self.module_dropout = 0 + self.use_tucker = False + self.loras = [] + self.algo_table = {} + return + self.multiplier = multiplier + self.lora_dim = lora_dim + + if not self.ENABLE_CONV: + conv_lora_dim = 0 + + self.conv_lora_dim = int(conv_lora_dim) + if self.conv_lora_dim and self.conv_lora_dim != self.lora_dim: + logger.info("Apply different lora dim for conv layer") + logger.info(f"Conv Dim: {conv_lora_dim}, Linear Dim: {lora_dim}") + elif self.conv_lora_dim == 0: + logger.info("Disable conv layer") + + self.alpha = alpha + self.conv_alpha = float(conv_alpha) + if self.conv_lora_dim and self.alpha != self.conv_alpha: + logger.info("Apply different alpha value for conv layer") + logger.info(f"Conv alpha: {conv_alpha}, Linear alpha: {alpha}") + + if 1 >= dropout >= 0: + logger.info(f"Use Dropout value: {dropout}") + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + self.use_tucker = use_tucker + + def create_single_module( + lora_name: str, + module: torch.nn.Module, + algo_name, + dim=None, + alpha=None, + use_tucker=self.use_tucker, + **kwargs, + ): + for k, v in root_kwargs.items(): + if k in kwargs: + continue + kwargs[k] = v + + if train_norm and "Norm" in module.__class__.__name__: + return norm_modules( + lora_name, + module, + self.multiplier, + self.rank_dropout, + self.module_dropout, + **kwargs, + ) + lora = None + if isinstance(module, torch.nn.Linear) and lora_dim > 0: + dim = dim or lora_dim + alpha = alpha or self.alpha + elif isinstance( + module, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d) + ): + k_size, *_ = module.kernel_size + if k_size == 1 and lora_dim > 0: + dim = dim or lora_dim + alpha = alpha or self.alpha + elif conv_lora_dim > 0 or dim: + dim = dim or conv_lora_dim + alpha = alpha or self.conv_alpha + else: + return None + else: + return None + lora = network_module_dict[algo_name]( + lora_name, + module, + self.multiplier, + dim, + alpha, + self.dropout, + self.rank_dropout, + self.module_dropout, + use_tucker, + **kwargs, + ) + return lora + + def create_modules_( + prefix: str, + root_module: torch.nn.Module, + algo, + current_lora_map: dict[str, Any], + configs={}, + ): + assert current_lora_map is not None, "No mapping supplied" + loras = current_lora_map + lora_names = [] + for name, module in root_module.named_modules(): + module_name = module.__class__.__name__ + if module_name in self.MODULE_ALGO_MAP and module is not root_module: + next_config = self.MODULE_ALGO_MAP[module_name] + next_algo = next_config.get("algo", algo) + new_loras, new_lora_names, new_lora_map = create_modules_( + f"{prefix}_{name}" if name else prefix, + module, + next_algo, + loras, + configs=next_config, + ) + loras = {**loras, **new_lora_map} + for lora_name, lora in zip(new_lora_names, new_loras): + if lora_name not in loras and lora_name not in current_lora_map: + loras[lora_name] = lora + if lora_name not in lora_names: + lora_names.append(lora_name) + continue + + if name: + lora_name = prefix + "." + name + else: + lora_name = prefix + + if f"{self.LORA_PREFIX}_." in lora_name: + lora_name = lora_name.replace( + f"{self.LORA_PREFIX}_.", + f"{self.LORA_PREFIX}.", + ) + + lora_name = lora_name.replace(".", "_") + if lora_name in loras: + continue + + lora = create_single_module(lora_name, module, algo, **configs) + if lora is not None: + loras[lora_name] = lora + lora_names.append(lora_name) + return [loras[lora_name] for lora_name in lora_names], lora_names, loras + + # create module instances + def create_modules( + prefix, + root_module: torch.nn.Module, + target_replace_modules, + target_replace_names=[], + target_exclude_names=[], + ) -> List: + logger.info("Create LyCORIS Module") + loras = [] + lora_map = {} + next_config = {} + for name, module in root_module.named_modules(): + if name in target_exclude_names or any( + self.match_fn(t, name) for t in target_exclude_names + ): + continue + + module_name = module.__class__.__name__ + if module_name in target_replace_modules and not any( + self.match_fn(t, name) for t in target_replace_names + ): + if module_name in self.MODULE_ALGO_MAP: + next_config = self.MODULE_ALGO_MAP[module_name] + algo = next_config.get("algo", network_module) + else: + algo = network_module + + lora_lst, _, _lora_map = create_modules_( + f"{prefix}_{name}", + module, + algo, + lora_map, + configs=next_config, + ) + lora_map = {**lora_map, **_lora_map} + loras.extend(lora_lst) + next_config = {} + elif name in target_replace_names or any( + self.match_fn(t, name) for t in target_replace_names + ): + conf_from_name = self.find_conf_for_name(name) + if conf_from_name is not None: + next_config = conf_from_name + algo = next_config.get("algo", network_module) + elif module_name in self.MODULE_ALGO_MAP: + next_config = self.MODULE_ALGO_MAP[module_name] + algo = next_config.get("algo", network_module) + else: + algo = network_module + lora_name = prefix + "." + name + lora_name = lora_name.replace(".", "_") + + if lora_name in lora_map: + continue + + lora = create_single_module(lora_name, module, algo, **next_config) + next_config = {} + if lora is not None: + lora_map[lora.lora_name] = lora + loras.append(lora) + return loras + + self.loras = create_modules( + LycorisNetwork.LORA_PREFIX, + module, + list( + set( + [ + *LycorisNetwork.TARGET_REPLACE_MODULE, + *LycorisNetwork.MODULE_ALGO_MAP.keys(), + ] + ) + ), + list( + set( + [ + *LycorisNetwork.TARGET_REPLACE_NAME, + *LycorisNetwork.NAME_ALGO_MAP.keys(), + ] + ) + ), + target_exclude_names=LycorisNetwork.TARGET_EXCLUDE_NAME, + ) + logger.info(f"create LyCORIS: {len(self.loras)} modules.") + + algo_table = {} + for lora in self.loras: + algo_table[lora.__class__.__name__] = ( + algo_table.get(lora.__class__.__name__, 0) + 1 + ) + logger.info(f"module type table: {algo_table}") + + # Assertion to ensure we have not accidentally wrapped some layers + # multiple times. + names = set() + for lora in self.loras: + assert ( + lora.lora_name not in names + ), f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def match_fn(self, pattern: str, name: str) -> bool: + if self.USE_FNMATCH: + return fnmatch.fnmatch(name, pattern) + return bool(re.match(pattern, name)) + + def find_conf_for_name( + self, + name: str, + ) -> dict[str, Any]: + if name in self.NAME_ALGO_MAP.keys(): + return self.NAME_ALGO_MAP[name] + + for key, value in self.NAME_ALGO_MAP.items(): + if self.match_fn(key, name): + return value + + return None + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.loras: + lora.multiplier = self.multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + self.weights_sd = load_file(file) + else: + self.weights_sd = torch.load(file, map_location="cpu") + missing, unexpected = self.load_state_dict(self.weights_sd, strict=False) + state = {} + if missing: + state["missing keys"] = missing + if unexpected: + state["unexpected keys"] = unexpected + return state + + def apply_to(self): + """ + Register to modules to the subclass so that torch sees them. + """ + for lora in self.loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + if self.weights_sd: + # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros) + info = self.load_state_dict(self.weights_sd, False) + logger.info(f"weights are loaded: {info}") + + def is_mergeable(self): + return True + + def restore(self): + for lora in self.loras: + lora.restore() + + def merge_to(self, weight=1.0): + for lora in self.loras: + lora.merge_to(weight) + + def apply_max_norm_regularization(self, max_norm_value, device): + key_scaled = 0 + norms = [] + for module in self.loras: + scaled, norm = module.apply_max_norm(max_norm_value, device) + if scaled is None: + continue + norms.append(norm) + key_scaled += scaled + + if key_scaled == 0: + return key_scaled, 0, 0 + + return key_scaled, sum(norms) / len(norms), max(norms) + + def enable_gradient_checkpointing(self): + # not supported + def make_ckpt(module): + if isinstance(module, torch.nn.Module): + module.grad_ckpt = True + + self.apply(make_ckpt) + pass + + def prepare_optimizer_params(self, lr): + def enumerate_params(loras): + params = [] + for lora in loras: + params.extend(lora.parameters()) + return params + + self.requires_grad_(True) + all_params = [] + + param_data = {"params": enumerate_params(self.loras)} + if lr is not None: + param_data["lr"] = lr + all_params.append(param_data) + return all_params + + def prepare_grad_etc(self, *args): + self.requires_grad_(True) + + def on_epoch_start(self, *args): + self.train() + + def get_trainable_params(self, *args): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) diff --git a/models.yaml b/models.yaml new file mode 100644 index 0000000000000000000000000000000000000000..81b9992c30cf38dfec566c98a839142b99036da9 --- /dev/null +++ b/models.yaml @@ -0,0 +1,27 @@ +# Add your own model here +# : +# repo: +# base: +# license: +# license_name: +# license_link: +# file: +flux-dev: + repo: cocktailpeanut/xulf-dev + base: black-forest-labs/FLUX.1-dev + license: other + license_name: flux-1-dev-non-commercial-license + license_link: https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md + file: flux1-dev.sft +flux-schnell: + repo: black-forest-labs/FLUX.1-schnell + base: black-forest-labs/FLUX.1-schnell + license: apache-2.0 + file: flux1-schnell.safetensors +bdsqlsz/flux1-dev2pro-single: + repo: bdsqlsz/flux1-dev2pro-single + base: black-forest-labs/FLUX.1-dev + license: other + license_name: flux-1-dev-non-commercial-license + license_link: https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md + file: flux1-dev2pro.safetensors diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..d432181baaa3f89b2679a5780e03be18d6be1801 --- /dev/null +++ b/networks/check_lora_weights.py @@ -0,0 +1,48 @@ +import argparse +import os +import torch +from safetensors.torch import load_file +from .utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) + +def main(file): + logger.info(f"loading: {file}") + if os.path.splitext(file)[1] == ".safetensors": + sd = load_file(file) + else: + sd = torch.load(file, map_location="cpu") + + values = [] + + keys = list(sd.keys()) + for key in keys: + if "lora_up" in key or "lora_down" in key: + values.append((key, sd[key])) + print(f"number of LoRA modules: {len(values)}") + + if args.show_all_keys: + for key in [k for k in keys if k not in values]: + values.append((key, sd[key])) + print(f"number of all modules: {len(values)}") + + for key, value in values: + value = value.to(torch.float32) + print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル") + parser.add_argument("-s", "--show_all_keys", action="store_true", help="show all keys / 全てのキーを表示する") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + + main(args.file) diff --git a/networks/lora.py b/networks/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..2e6ff1ddfd388c8244cfa4d457f1464407f3b18d --- /dev/null +++ b/networks/lora.py @@ -0,0 +1,1403 @@ +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +from transformers import CLIPTextModel +import numpy as np +import torch +import re +from ..library.utils import setup_logging +from ..library.sdxl_original_unet import SdxlUNet2DConditionModel + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + # if limit_rank: + # self.lora_dim = min(lora_dim, in_dim, out_dim) + # if self.lora_dim != lora_dim: + # logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") + # else: + self.lora_dim = lora_dim + + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded + lx * self.multiplier * scale + + +class LoRAInfModule(LoRAModule): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) + + self.org_module_ref = [org_module] # 後から参照できるように + self.enabled = True + + # check regional or not by lora_name + self.text_encoder = False + if lora_name.startswith("lora_te_"): + self.regional = False + self.use_sub_prompt = True + self.text_encoder = True + elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name: + self.regional = False + self.use_sub_prompt = True + elif "time_emb" in lora_name: + self.regional = False + self.use_sub_prompt = False + else: + self.regional = True + self.use_sub_prompt = False + + self.network: LoRANetwork = None + + def set_network(self, network): + self.network = network + + # freezeしてマージする + def merge_to(self, sd, dtype, device): + # get up/down weight + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"].to(torch.float) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + # 復元できるマージのため、このモジュールのweightを返す + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + # get up/down weight from module + up_weight = self.lora_up.weight.to(torch.float) + down_weight = self.lora_down.weight.to(torch.float) + + # pre-calculated weight + if len(down_weight.size()) == 2: + # linear + weight = self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = self.multiplier * conved * self.scale + + return weight + + def set_region(self, region): + self.region = region + self.region_mask = None + + def default_forward(self, x): + # logger.info(f"default_forward {self.lora_name} {x.size()}") + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + + if self.network is None or self.network.sub_prompt_index is None: + return self.default_forward(x) + if not self.regional and not self.use_sub_prompt: + return self.default_forward(x) + + if self.regional: + return self.regional_forward(x) + else: + return self.sub_prompt_forward(x) + + def get_mask_for_x(self, x): + # calculate size from shape of x + if len(x.size()) == 4: + h, w = x.size()[2:4] + area = h * w + else: + area = x.size()[1] + + mask = self.network.mask_dic.get(area, None) + if mask is None or len(x.size()) == 2: + # emb_layers in SDXL doesn't have mask + # if "emb" not in self.lora_name: + # print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}") + mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1) + return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts + if len(x.size()) == 3: + mask = torch.reshape(mask, (1, -1, 1)) + return mask + + def regional_forward(self, x): + if "attn2_to_out" in self.lora_name: + return self.to_out_forward(x) + + if self.network.mask_dic is None: # sub_prompt_index >= 3 + return self.default_forward(x) + + # apply mask for LoRA result + lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale + mask = self.get_mask_for_x(lx) + # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) + # if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked) + # mask = mask.squeeze(-1) + lx = lx * mask + + x = self.org_forward(x) + x = x + lx + + if "attn2_to_q" in self.lora_name and self.network.is_last_network: + x = self.postp_to_q(x) + + return x + + def postp_to_q(self, x): + # repeat x to num_sub_prompts + has_real_uncond = x.size()[0] // self.network.batch_size == 3 + qc = self.network.batch_size # uncond + qc += self.network.batch_size * self.network.num_sub_prompts # cond + if has_real_uncond: + qc += self.network.batch_size # real_uncond + + query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype) + query[: self.network.batch_size] = x[: self.network.batch_size] + + for i in range(self.network.batch_size): + qi = self.network.batch_size + i * self.network.num_sub_prompts + query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i] + + if has_real_uncond: + query[-self.network.batch_size :] = x[-self.network.batch_size :] + + # logger.info(f"postp_to_q {self.lora_name} {x.size()} {query.size()} {self.network.num_sub_prompts}") + return query + + def sub_prompt_forward(self, x): + if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA + return self.org_forward(x) + + emb_idx = self.network.sub_prompt_index + if not self.text_encoder: + emb_idx += self.network.batch_size + + # apply sub prompt of X + lx = x[emb_idx :: self.network.num_sub_prompts] + lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale + + # logger.info(f"sub_prompt_forward {self.lora_name} {x.size()} {lx.size()} {emb_idx}") + + x = self.org_forward(x) + x[emb_idx :: self.network.num_sub_prompts] += lx + + return x + + def to_out_forward(self, x): + # logger.info(f"to_out_forward {self.lora_name} {x.size()} {self.network.is_last_network}") + + if self.network.is_last_network: + masks = [None] * self.network.num_sub_prompts + self.network.shared[self.lora_name] = (None, masks) + else: + lx, masks = self.network.shared[self.lora_name] + + # call own LoRA + x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts] + lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale + + if self.network.is_last_network: + lx = torch.zeros( + (self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype + ) + self.network.shared[self.lora_name] = (lx, masks) + + # logger.info(f"to_out_forward {lx.size()} {lx1.size()} {self.network.sub_prompt_index} {self.network.num_sub_prompts}") + lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 + masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) + + # if not last network, return x and masks + x = self.org_forward(x) + if not self.network.is_last_network: + return x + + lx, masks = self.network.shared.pop(self.lora_name) + + # if last network, combine separated x with mask weighted sum + has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2 + + out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype) + out[: self.network.batch_size] = x[: self.network.batch_size] # uncond + if has_real_uncond: + out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond + + # logger.info(f"to_out_forward {self.lora_name} {self.network.sub_prompt_index} {self.network.num_sub_prompts}") + # if num_sub_prompts > num of LoRAs, fill with zero + for i in range(len(masks)): + if masks[i] is None: + masks[i] = torch.zeros_like(masks[0]) + + mask = torch.cat(masks) + mask_sum = torch.sum(mask, dim=0) + 1e-4 + for i in range(self.network.batch_size): + # 1枚の画像ごとに処理する + lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts] + lx1 = lx1 * mask + lx1 = torch.sum(lx1, dim=0) + + xi = self.network.batch_size + i * self.network.num_sub_prompts + x1 = x[xi : xi + self.network.num_sub_prompts] + x1 = x1 * mask + x1 = torch.sum(x1, dim=0) + x1 = x1 / mask_sum + + x1 = x1 + lx1 + out[self.network.batch_size + i] = x1 + + # logger.info(f"to_out_forward {x.size()} {out.size()} {has_real_uncond}") + return out + + +def parse_block_lr_kwargs(is_sdxl: bool, nw_kwargs: Dict) -> Optional[List[float]]: + down_lr_weight = nw_kwargs.get("down_lr_weight", None) + mid_lr_weight = nw_kwargs.get("mid_lr_weight", None) + up_lr_weight = nw_kwargs.get("up_lr_weight", None) + + # 以上のいずれにも設定がない場合は無効としてNoneを返す + if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None: + return None + + # extract learning rate weight for each block + if down_lr_weight is not None: + # if some parameters are not set, use zero + if "," in down_lr_weight: + down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] + + if mid_lr_weight is not None: + mid_lr_weight = [(float(s) if s else 0.0) for s in mid_lr_weight.split(",")] + + if up_lr_weight is not None: + if "," in up_lr_weight: + up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] + + return get_block_lr_weight( + is_sdxl, down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0)) + ) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: AutoencoderKL, + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], + unet, + neuron_dropout: Optional[float] = None, + **kwargs, +): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel) + + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # block dim/alpha/lr + block_dims = kwargs.get("block_dims", None) + block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs) + + # 以上のいずれかに指定があればblockごとのdim(rank)を有効にする + if block_dims is not None or block_lr_weight is not None: + block_alphas = kwargs.get("block_alphas", None) + conv_block_dims = kwargs.get("conv_block_dims", None) + conv_block_alphas = kwargs.get("conv_block_alphas", None) + + block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas( + is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha + ) + + # remove block dim/alpha without learning rate + block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas( + is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight + ) + + else: + block_alphas = None + conv_block_dims = None + conv_block_alphas = None + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoder, + unet, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + block_dims=block_dims, + block_alphas=block_alphas, + conv_block_dims=conv_block_dims, + conv_block_alphas=conv_block_alphas, + varbose=True, + is_sdxl=is_sdxl, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + if block_lr_weight is not None: + network.set_block_lr_weight(block_lr_weight) + + return network + + +# このメソッドは外部から呼び出される可能性を考慮しておく +# network_dim, network_alpha にはデフォルト値が入っている。 +# block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている +# conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている +def get_block_dims_and_alphas( + is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha +): + if not is_sdxl: + num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS + else: + # 1+9+3+9+1=23, no LoRA for emb_layers (0) + num_total_blocks = 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1 + + def parse_ints(s): + return [int(i) for i in s.split(",")] + + def parse_floats(s): + return [float(i) for i in s.split(",")] + + # block_dimsとblock_alphasをパースする。必ず値が入る + if block_dims is not None: + block_dims = parse_ints(block_dims) + assert len(block_dims) == num_total_blocks, ( + f"block_dims must have {num_total_blocks} elements but {len(block_dims)} elements are given" + + f" / block_dimsは{num_total_blocks}個指定してください(指定された個数: {len(block_dims)})" + ) + else: + logger.warning( + f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります" + ) + block_dims = [network_dim] * num_total_blocks + + if block_alphas is not None: + block_alphas = parse_floats(block_alphas) + assert ( + len(block_alphas) == num_total_blocks + ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" + else: + logger.warning( + f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" + ) + block_alphas = [network_alpha] * num_total_blocks + + # conv_block_dimsとconv_block_alphasを、指定がある場合のみパースする。指定がなければconv_dimとconv_alphaを使う + if conv_block_dims is not None: + conv_block_dims = parse_ints(conv_block_dims) + assert ( + len(conv_block_dims) == num_total_blocks + ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください" + + if conv_block_alphas is not None: + conv_block_alphas = parse_floats(conv_block_alphas) + assert ( + len(conv_block_alphas) == num_total_blocks + ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください" + else: + if conv_alpha is None: + conv_alpha = 1.0 + logger.warning( + f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" + ) + conv_block_alphas = [conv_alpha] * num_total_blocks + else: + if conv_dim is not None: + logger.warning( + f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" + ) + conv_block_dims = [conv_dim] * num_total_blocks + conv_block_alphas = [conv_alpha] * num_total_blocks + else: + conv_block_dims = None + conv_block_alphas = None + + return block_dims, block_alphas, conv_block_dims, conv_block_alphas + + +# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出せるようにclass外に出しておく +# 戻り値は block ごとの倍率のリスト +def get_block_lr_weight( + is_sdxl, + down_lr_weight: Union[str, List[float]], + mid_lr_weight: List[float], + up_lr_weight: Union[str, List[float]], + zero_threshold: float, +) -> Optional[List[float]]: + # パラメータ未指定時は何もせず、今までと同じ動作とする + if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None: + return None + + if not is_sdxl: + max_len_for_down_or_up = LoRANetwork.NUM_OF_BLOCKS + max_len_for_mid = LoRANetwork.NUM_OF_MID_BLOCKS + else: + max_len_for_down_or_up = LoRANetwork.SDXL_NUM_OF_BLOCKS + max_len_for_mid = LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + + def get_list(name_with_suffix) -> List[float]: + import math + + tokens = name_with_suffix.split("+") + name = tokens[0] + base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0 + + if name == "cosine": + return [ + math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr + for i in reversed(range(max_len_for_down_or_up)) + ] + elif name == "sine": + return [math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr for i in range(max_len_for_down_or_up)] + elif name == "linear": + return [i / (max_len_for_down_or_up - 1) + base_lr for i in range(max_len_for_down_or_up)] + elif name == "reverse_linear": + return [i / (max_len_for_down_or_up - 1) + base_lr for i in reversed(range(max_len_for_down_or_up))] + elif name == "zeros": + return [0.0 + base_lr] * max_len_for_down_or_up + else: + logger.error( + "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" + % (name) + ) + return None + + if type(down_lr_weight) == str: + down_lr_weight = get_list(down_lr_weight) + if type(up_lr_weight) == str: + up_lr_weight = get_list(up_lr_weight) + + if (up_lr_weight != None and len(up_lr_weight) > max_len_for_down_or_up) or ( + down_lr_weight != None and len(down_lr_weight) > max_len_for_down_or_up + ): + logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len_for_down_or_up) + logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_down_or_up) + up_lr_weight = up_lr_weight[:max_len_for_down_or_up] + down_lr_weight = down_lr_weight[:max_len_for_down_or_up] + + if mid_lr_weight != None and len(mid_lr_weight) > max_len_for_mid: + logger.warning("mid_weight is too long. Parameters after %d-th are ignored." % max_len_for_mid) + logger.warning("mid_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_mid) + mid_lr_weight = mid_lr_weight[:max_len_for_mid] + + if (up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up) or ( + down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up + ): + logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_down_or_up) + logger.warning( + "down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_down_or_up + ) + + if down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up: + down_lr_weight = down_lr_weight + [1.0] * (max_len_for_down_or_up - len(down_lr_weight)) + if up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up: + up_lr_weight = up_lr_weight + [1.0] * (max_len_for_down_or_up - len(up_lr_weight)) + + if mid_lr_weight != None and len(mid_lr_weight) < max_len_for_mid: + logger.warning("mid_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_mid) + logger.warning("mid_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_mid) + mid_lr_weight = mid_lr_weight + [1.0] * (max_len_for_mid - len(mid_lr_weight)) + + if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): + logger.info("apply block learning rate / 階層別学習率を適用します。") + if down_lr_weight != None: + down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] + logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}") + else: + down_lr_weight = [1.0] * max_len_for_down_or_up + logger.info("down_lr_weight: all 1.0, すべて1.0") + + if mid_lr_weight != None: + mid_lr_weight = [w if w > zero_threshold else 0 for w in mid_lr_weight] + logger.info(f"mid_lr_weight: {mid_lr_weight}") + else: + mid_lr_weight = [1.0] * max_len_for_mid + logger.info("mid_lr_weight: all 1.0, すべて1.0") + + if up_lr_weight != None: + up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] + logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}") + else: + up_lr_weight = [1.0] * max_len_for_down_or_up + logger.info("up_lr_weight: all 1.0, すべて1.0") + + lr_weight = down_lr_weight + mid_lr_weight + up_lr_weight + + if is_sdxl: + lr_weight = [1.0] + lr_weight + [1.0] # add 1.0 for emb_layers and out + + assert (not is_sdxl and len(lr_weight) == LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS) or ( + is_sdxl and len(lr_weight) == 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1 + ), f"lr_weight length is invalid: {len(lr_weight)}" + + return lr_weight + + +# lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく +def remove_block_dims_and_alphas( + is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight: Optional[List[float]] +): + if block_lr_weight is not None: + for i, lr in enumerate(block_lr_weight): + if lr == 0: + block_dims[i] = 0 + if conv_block_dims is not None: + conv_block_dims[i] = 0 + return block_dims, block_alphas, conv_block_dims, conv_block_alphas + + +# 外部から呼び出す可能性を考慮しておく +def get_block_index(lora_name: str, is_sdxl: bool = False) -> int: + block_idx = -1 # invalid lora name + if not is_sdxl: + m = RE_UPDOWN.search(lora_name) + if m: + g = m.groups() + i = int(g[1]) + j = int(g[3]) + if g[2] == "resnets": + idx = 3 * i + j + elif g[2] == "attentions": + idx = 3 * i + j + elif g[2] == "upsamplers" or g[2] == "downsamplers": + idx = 3 * i + 2 + + if g[0] == "down": + block_idx = 1 + idx # 0に該当するLoRAは存在しない + elif g[0] == "up": + block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx + elif "mid_block_" in lora_name: + block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 + else: + # copy from sdxl_train + if lora_name.startswith("lora_unet_"): + name = lora_name[len("lora_unet_") :] + if name.startswith("time_embed_") or name.startswith("label_emb_"): # No LoRA + block_idx = 0 # 0 + elif name.startswith("input_blocks_"): # 1-9 + block_idx = 1 + int(name.split("_")[2]) + elif name.startswith("middle_block_"): # 10-12 + block_idx = 10 + int(name.split("_")[2]) + elif name.startswith("output_blocks_"): # 13-21 + block_idx = 13 + int(name.split("_")[2]) + elif name.startswith("out_"): # 22, out, no LoRA + block_idx = 22 + + return block_idx + + +def convert_diffusers_to_sai_if_needed(weights_sd): + # only supports U-Net LoRA modules + + found_up_down_blocks = False + for k in list(weights_sd.keys()): + if "down_blocks" in k: + found_up_down_blocks = True + break + if "up_blocks" in k: + found_up_down_blocks = True + break + if not found_up_down_blocks: + return + + from ..library.sdxl_model_util import make_unet_conversion_map + + unet_conversion_map = make_unet_conversion_map() + unet_conversion_map = {hf.replace(".", "_")[:-1]: sd.replace(".", "_")[:-1] for sd, hf in unet_conversion_map} + + # # add extra conversion + # unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv" + + logger.info(f"Converting LoRA keys from Diffusers to SAI") + lora_unet_prefix = "lora_unet_" + for k in list(weights_sd.keys()): + if not k.startswith(lora_unet_prefix): + continue + + unet_module_name = k[len(lora_unet_prefix) :].split(".")[0] + + # search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small + for hf_module_name, sd_module_name in unet_conversion_map.items(): + if hf_module_name in unet_module_name: + new_key = ( + lora_unet_prefix + + unet_module_name.replace(hf_module_name, sd_module_name) + + k[len(lora_unet_prefix) + len(unet_module_name) :] + ) + weights_sd[new_key] = weights_sd.pop(k) + found = True + break + + if not found: + logger.warning(f"Key {k} is not found in unet_conversion_map") + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel) + + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # if keys are Diffusers based, convert to SAI based + convert_diffusers_to_sai_if_needed(weights_sd) + + # get dim/alpha mapping + modules_dim = {} + modules_alpha = {} + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + # support old LoRA without alpha + for key in modules_dim.keys(): + if key not in modules_alpha: + modules_alpha[key] = modules_dim[key] + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class + ) + + # block lr + block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs) + if block_lr_weight is not None: + network.set_block_lr_weight(block_lr_weight) + + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 + NUM_OF_MID_BLOCKS = 1 + SDXL_NUM_OF_BLOCKS = 9 # SDXLのモデルでのinput/outputの層の数 total=1(base) 9(input) + 3(mid) + 9(output) + 1(out) = 23 + SDXL_NUM_OF_MID_BLOCKS = 3 + + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + + # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER + LORA_PREFIX_TEXT_ENCODER1 = "lora_te1" + LORA_PREFIX_TEXT_ENCODER2 = "lora_te2" + + def __init__( + self, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + block_dims: Optional[List[int]] = None, + block_alphas: Optional[List[float]] = None, + conv_block_dims: Optional[List[int]] = None, + conv_block_alphas: Optional[List[float]] = None, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + module_class: Type[object] = LoRAModule, + varbose: Optional[bool] = False, + is_sdxl: Optional[bool] = False, + ) -> None: + """ + LoRA network: すごく引数が多いが、パターンは以下の通り + 1. lora_dimとalphaを指定 + 2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定 + 3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない + 4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する + 5. modules_dimとmodules_alphaを指定 (推論用) + """ + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + elif block_dims is not None: + logger.info(f"create LoRA network from block_dims") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + logger.info(f"block_dims: {block_dims}") + logger.info(f"block_alphas: {block_alphas}") + if conv_block_dims is not None: + logger.info(f"conv_block_dims: {conv_block_dims}") + logger.info(f"conv_block_alphas: {conv_block_alphas}") + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + if self.conv_lora_dim is not None: + logger.info( + f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + ) + + # create module instances + def create_modules( + is_unet: bool, + text_encoder_idx: Optional[int], # None, 1, 2 + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_UNET + if is_unet + else ( + self.LORA_PREFIX_TEXT_ENCODER + if text_encoder_idx is None + else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2) + ) + ) + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + elif is_unet and block_dims is not None: + # U-Netでblock_dims指定あり + block_idx = get_block_index(lora_name, is_sdxl) + if is_linear or is_conv2d_1x1: + dim = block_dims[block_idx] + alpha = block_alphas[block_idx] + elif conv_block_dims is not None: + dim = conv_block_dims[block_idx] + alpha = conv_block_alphas[block_idx] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None): + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + ) + loras.append(lora) + return loras, skipped + + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + if len(text_encoders) > 1: + index = i + 1 + logger.info(f"create LoRA for Text Encoder {index}:") + else: + index = None + logger.info(f"create LoRA for Text Encoder:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None: + target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + + self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) + logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + skipped = skipped_te + skipped_un + if varbose and len(skipped) > 0: + logger.warning( + f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + self.block_lr_weight = None + self.block_lr = False + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない + def set_block_lr_weight(self, block_lr_weight: Optional[List[float]]): + self.block_lr = True + self.block_lr_weight = block_lr_weight + + def get_lr_weight(self, block_idx: int) -> float: + if not self.block_lr or self.block_lr_weight is None: + return 1.0 + return self.block_lr_weight[block_idx] + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + # TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?) + # if ( + # self.loraplus_lr_ratio is not None + # or self.loraplus_text_encoder_lr_ratio is not None + # or self.loraplus_unet_lr_ratio is not None + # ): + # assert ( + # optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower() + # ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません" + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + params, descriptions = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + if self.block_lr: + is_sdxl = False + for lora in self.unet_loras: + if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name: + is_sdxl = True + break + + # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 + block_idx_to_lora = {} + for lora in self.unet_loras: + idx = get_block_index(lora.lora_name, is_sdxl) + if idx not in block_idx_to_lora: + block_idx_to_lora[idx] = [] + block_idx_to_lora[idx].append(lora) + + # blockごとにパラメータを設定する + for idx, block_loras in block_idx_to_lora.items(): + params, descriptions = assemble_params( + block_loras, + (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx), + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions]) + + else: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from ..library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + # mask is a tensor with values from 0 to 1 + def set_region(self, sub_prompt_index, is_last_network, mask): + if mask.max() == 0: + mask = torch.ones_like(mask) + + self.mask = mask + self.sub_prompt_index = sub_prompt_index + self.is_last_network = is_last_network + + for lora in self.text_encoder_loras + self.unet_loras: + lora.set_network(self) + + def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None): + self.batch_size = batch_size + self.num_sub_prompts = num_sub_prompts + self.current_size = (height, width) + self.shared = shared + + # create masks + mask = self.mask + mask_dic = {} + mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w + ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight + dtype = ref_weight.dtype + device = ref_weight.device + + def resize_add(mh, mw): + # logger.info(mh, mw, mh * mw) + m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 + m = m.to(device, dtype=dtype) + mask_dic[mh * mw] = m + + h = height // 8 + w = width // 8 + for _ in range(4): + resize_add(h, w) + if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2 + resize_add(h + h % 2, w + w % 2) + + # deep shrink + if ds_ratio is not None: + hd = int(h * ds_ratio) + wd = int(w * ds_ratio) + resize_add(hd, wd) + + h = (h + 1) // 2 + w = (w + 1) // 2 + + self.mask_dic = mask_dic + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/networks/lora_flux.py b/networks/lora_flux.py new file mode 100644 index 0000000000000000000000000000000000000000..5fb863332f521af45ebd40eda4c19b619987131b --- /dev/null +++ b/networks/lora_flux.py @@ -0,0 +1,1032 @@ +# temporary minimum implementation of LoRA +# FLUX doesn't have Conv2d, so we ignore it +# TODO commonize with the original implementation + +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +from transformers import CLIPTextModel +import numpy as np +import torch +import re +#from ..library.utils import setup_logging + +#setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + split_dims: Optional[List[int]] = None, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + self.split_dims = split_dims + + if split_dims is None: + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + else: + # conv2d not supported + assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim" + assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear" + # print(f"split_dims: {split_dims}") + self.lora_down = torch.nn.ModuleList( + [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] + ) + self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) + for lora_down in self.lora_down: + torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) + for lora_up in self.lora_up: + torch.nn.init.zeros_(lora_up.weight) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + if self.split_dims is None: + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded + lx * self.multiplier * scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + + # normal dropout + if self.dropout is not None and self.training: + lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs] + + # rank dropout + if self.rank_dropout is not None and self.training: + masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs] + for i in range(len(lxs)): + if len(lx.size()) == 3: + masks[i] = masks[i].unsqueeze(1) + elif len(lx.size()) == 4: + masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1) + lxs[i] = lxs[i] * masks[i] + + # scaling for rank dropout: treat as if the rank is changed + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + + return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale + + +class LoRAInfModule(LoRAModule): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) + + self.org_module_ref = [org_module] # 後から参照できるように + self.enabled = True + self.network: LoRANetwork = None + + def set_network(self, network): + self.network = network + + # freezeしてマージする + def merge_to(self, sd, dtype, device): + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"] + org_dtype = weight.dtype + org_device = weight.device + weight = weight.to(torch.float) # calc in float + + if dtype is None: + dtype = org_dtype + if device is None: + device = org_device + + if self.split_dims is None: + # get up/down weight + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + else: + # split_dims + total_dims = sum(self.split_dims) + for i in range(len(self.split_dims)): + # get up/down weight + down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim) + up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank) + + # pad up_weight -> (total_dims, rank) + padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float) + padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight + + # merge weight + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + # 復元できるマージのため、このモジュールのweightを返す + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + # get up/down weight from module + up_weight = self.lora_up.weight.to(torch.float) + down_weight = self.lora_down.weight.to(torch.float) + + # pre-calculated weight + if len(down_weight.size()) == 2: + # linear + weight = self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = self.multiplier * conved * self.scale + + return weight + + def set_region(self, region): + self.region = region + self.region_mask = None + + def default_forward(self, x): + # logger.info(f"default_forward {self.lora_name} {x.size()}") + if self.split_dims is None: + lx = self.lora_down(x) + lx = self.lora_up(lx) + return self.org_forward(x) + lx * self.multiplier * self.scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + return self.default_forward(x) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + ae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + flux, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # single or double blocks + train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double" + if train_blocks is not None: + assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}" + + only_if_contains = kwargs.get("only_if_contains", None) + if only_if_contains is not None: + only_if_contains = [word.strip() for word in only_if_contains.split(',')] + + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + + # train T5XXL + train_t5xxl = kwargs.get("train_t5xxl", False) + if train_t5xxl is not None: + train_t5xxl = True if train_t5xxl == "True" else False + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoders, + flux, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + train_blocks=train_blocks, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + varbose=True, + only_if_contains=only_if_contains + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping, and train t5xxl + modules_dim = {} + modules_alpha = {} + train_t5xxl = None + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + if train_t5xxl is None or train_t5xxl is False: + train_t5xxl = "lora_te3" in lora_name + + if train_t5xxl is None: + train_t5xxl = False + + # # split qkv + # double_qkv_rank = None + # single_qkv_rank = None + # rank = None + # for lora_name, dim in modules_dim.items(): + # if "double" in lora_name and "qkv" in lora_name: + # double_qkv_rank = dim + # elif "single" in lora_name and "linear1" in lora_name: + # single_qkv_rank = dim + # elif rank is None: + # rank = dim + # if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None: + # break + # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or ( + # single_qkv_rank is not None and single_qkv_rank != rank + # ) + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoders, + flux, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + ) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"] + LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible + LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible + + def __init__( + self, + text_encoders: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + train_blocks: Optional[str] = None, + split_qkv: bool = False, + train_t5xxl: bool = False, + varbose: Optional[bool] = False, + only_if_contains: Optional[List[str]] = None, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.train_blocks = train_blocks if train_blocks is not None else "all" + self.split_qkv = split_qkv + self.train_t5xxl = train_t5xxl + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + self.only_if_contains = only_if_contains + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) + if self.split_qkv: + logger.info(f"split qkv for LoRA") + if self.train_blocks is not None: + logger.info(f"train {self.train_blocks} blocks only") + if train_t5xxl: + logger.info(f"train T5XXL as well") + + #self.only_if_contains = ["lora_unet_single_blocks_20_linear2"] + + # create module instances + def create_modules( + is_flux: bool, text_encoder_idx: Optional[int], root_module: torch.nn.Module, target_replace_modules: List[str] + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_FLUX + if is_flux + else (self.LORA_PREFIX_TEXT_ENCODER_CLIP if text_encoder_idx == 0 else self.LORA_PREFIX_TEXT_ENCODER_T5) + ) + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + #lora_unet_single_blocks_20_linear2 + + if "unet" in lora_name and (self.only_if_contains is not None and not any(word in lora_name for word in self.only_if_contains)): + continue + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + # qkv split + split_dims = None + if is_flux and split_qkv: + if "double" in lora_name and "qkv" in lora_name: + split_dims = [3072] * 3 + elif "single" in lora_name and "linear1" in lora_name: + split_dims = [3072] * 3 + [12288] + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + split_dims=split_dims, + ) + loras.append(lora) + return loras, skipped + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + index = i + if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False + break + + logger.info(f"create LoRA for Text Encoder {index+1}:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.") + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + + # create LoRA for U-Net + if self.train_blocks == "all": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE + elif self.train_blocks == "single": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE + elif self.train_blocks == "double": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules) + logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.") + #print(self.unet_loras) + + skipped = skipped_te + skipped_un + if varbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to split qkv + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # split qkv + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + continue + + weight = state_dict[key] + lora_name = key.split(".")[0] + if "lora_down" in key and "weight" in key: + # dense weight (rank*3, in_dim) + split_weight = torch.chunk(weight, len(split_dims), dim=0) + for i, split_w in enumerate(split_weight): + state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w + + del state_dict[key] + # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}") + elif "lora_up" in key and "weight" in key: + # sparse weight (out_dim=sum(split_dims), rank*3) + rank = weight.size(1) // len(split_dims) + i = 0 + for j in range(len(split_dims)): + state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dims[j], j * rank : (j + 1) * rank] + i += split_dims[j] + del state_dict[key] + + # # check is sparse + # i = 0 + # is_zero = True + # for j in range(len(split_dims)): + # for k in range(len(split_dims)): + # if j == k: + # continue + # is_zero = is_zero and torch.all(weight[i : i + split_dims[j], k * rank : (k + 1) * rank] == 0) + # i += split_dims[j] + # if not is_zero: + # logger.warning(f"weight is not sparse: {key}") + # else: + # logger.info(f"weight is sparse: {key}") + + # print( + # f"split {key}: {weight.shape} to {[state_dict[k].shape for k in [f'{lora_name}.lora_up.{j}.weight' for j in range(len(split_dims))]]}" + # ) + + # alpha is unchanged + + return super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination, prefix, keep_vars) + + # merge qkv + state_dict = super().state_dict(destination, prefix, keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + rank = up_weights[0].size(1) + up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(len(split_dims)): + up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j] + i += split_dims[j] + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + + def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, flux, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP) or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_FLUX): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of two elements + # if float, use the same value for both text encoders + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): + text_encoder_lr = [default_lr, default_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr)] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, loraplus_ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if loraplus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * loraplus_ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP)] + te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)] + if len(te1_loras) > 0: + logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te3_loras) > 0: + logger.info(f"Text Encoder 2 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[1]}") + params, descriptions = assemble_params(te3_loras, text_encoder_lr[1], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 2 " + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) + +def precalculate_safetensors_hashes(tensors, metadata): + """Precalculate the model hashes needed by sd-webui-additional-networks to + save time on indexing the model later.""" + + import hashlib + import safetensors.torch + from io import BytesIO + + # Retain only training metadata for hash calculation + metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")} + bytes = safetensors.torch.save(tensors, metadata) + b = BytesIO(bytes) + + def addnet_hash_legacy(b): + """Old model hash used by sd-webui-additional-networks for .safetensors format files""" + m = hashlib.sha256() + b.seek(0x100000) + m.update(b.read(0x10000)) + return m.hexdigest()[0:8] + + def addnet_hash_safetensors(b): + """New model hash used by sd-webui-additional-networks for .safetensors format files""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) + return hash_sha256.hexdigest() + + model_hash = addnet_hash_safetensors(b) + legacy_hash = addnet_hash_legacy(b) + return model_hash, legacy_hash \ No newline at end of file diff --git a/networks/lora_sd3.py b/networks/lora_sd3.py new file mode 100644 index 0000000000000000000000000000000000000000..c4513eb227859af49c56afe175be1514a5501a4f --- /dev/null +++ b/networks/lora_sd3.py @@ -0,0 +1,837 @@ +# temporary minimum implementation of LoRA +# SD3 doesn't have Conv2d, so we ignore it +# TODO commonize with the original/SD3/FLUX implementation + +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from transformers import CLIPTextModelWithProjection, T5EncoderModel +import torch +from ..library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +from .lora_flux import LoRAModule, LoRAInfModule +from ..library import sd3_models + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: sd3_models.SDVAE, + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], + mmdit, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv + context_attn_dim = kwargs.get("context_attn_dim", None) + context_mlp_dim = kwargs.get("context_mlp_dim", None) + context_mod_dim = kwargs.get("context_mod_dim", None) + x_attn_dim = kwargs.get("x_attn_dim", None) + x_mlp_dim = kwargs.get("x_mlp_dim", None) + x_mod_dim = kwargs.get("x_mod_dim", None) + if context_attn_dim is not None: + context_attn_dim = int(context_attn_dim) + if context_mlp_dim is not None: + context_mlp_dim = int(context_mlp_dim) + if context_mod_dim is not None: + context_mod_dim = int(context_mod_dim) + if x_attn_dim is not None: + x_attn_dim = int(x_attn_dim) + if x_mlp_dim is not None: + x_mlp_dim = int(x_mlp_dim) + if x_mod_dim is not None: + x_mod_dim = int(x_mod_dim) + type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim] + if all([d is None for d in type_dims]): + type_dims = None + + # emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear] + emb_dims = kwargs.get("emb_dims", None) + if emb_dims is not None: + emb_dims = emb_dims.strip() + if emb_dims.startswith("[") and emb_dims.endswith("]"): + emb_dims = emb_dims[1:-1] + emb_dims = [int(d) for d in emb_dims.split(",")] # is it better to use ast.literal_eval? + assert len(emb_dims) == 6, f"invalid emb_dims: {emb_dims}, must be 6 dimensions (context, t, x, y, final_mod, final_linear)" + + # double/single train blocks + def parse_block_selection(selection: str, total_blocks: int) -> List[bool]: + """ + Parse a block selection string and return a list of booleans. + + Args: + selection (str): A string specifying which blocks to select. + total_blocks (int): The total number of blocks available. + + Returns: + List[bool]: A list of booleans indicating which blocks are selected. + """ + if selection == "all": + return [True] * total_blocks + if selection == "none" or selection == "": + return [False] * total_blocks + + selected = [False] * total_blocks + ranges = selection.split(",") + + for r in ranges: + if "-" in r: + start, end = map(str.strip, r.split("-")) + start = int(start) + end = int(end) + assert 0 <= start < total_blocks, f"invalid start index: {start}" + assert 0 <= end < total_blocks, f"invalid end index: {end}" + assert start <= end, f"invalid range: {start}-{end}" + for i in range(start, end + 1): + selected[i] = True + else: + index = int(r) + assert 0 <= index < total_blocks, f"invalid index: {index}" + selected[index] = True + + return selected + + train_block_indices = kwargs.get("train_block_indices", None) + if train_block_indices is not None: + train_block_indices = parse_block_selection(train_block_indices, 999) # 999 is a dummy number + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + + # train T5XXL + train_t5xxl = kwargs.get("train_t5xxl", False) + if train_t5xxl is not None: + train_t5xxl = True if train_t5xxl == "True" else False + + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoders, + mmdit, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + type_dims=type_dims, + emb_dims=emb_dims, + train_block_indices=train_block_indices, + verbose=verbose, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, mmdit, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping, and train t5xxl + modules_dim = {} + modules_alpha = {} + train_t5xxl = None + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + if train_t5xxl is None or train_t5xxl is False: + train_t5xxl = "lora_te3" in lora_name + + if train_t5xxl is None: + train_t5xxl = False + + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoders, + mmdit, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + ) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + SD3_TARGET_REPLACE_MODULE = ["SingleDiTBlock"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"] + LORA_PREFIX_SD3 = "lora_unet" # make ComfyUI compatible + LORA_PREFIX_TEXT_ENCODER_CLIP_L = "lora_te1" + LORA_PREFIX_TEXT_ENCODER_CLIP_G = "lora_te2" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible + + def __init__( + self, + text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]], + unet: sd3_models.MMDiT, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + split_qkv: bool = False, + train_t5xxl: bool = False, + type_dims: Optional[List[int]] = None, + emb_dims: Optional[List[int]] = None, + train_block_indices: Optional[List[bool]] = None, + verbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.split_qkv = split_qkv + self.train_t5xxl = train_t5xxl + + self.type_dims = type_dims + self.emb_dims = emb_dims + self.train_block_indices = train_block_indices + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + self.emb_dims = [0] * 6 # create emb_dims + # verbose = True + else: + logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) + + qkv_dim = 0 + if self.split_qkv: + logger.info(f"split qkv for LoRA") + qkv_dim = unet.joint_blocks[0].context_block.attn.qkv.weight.size(0) + if train_t5xxl: + logger.info(f"train T5XXL as well") + + # create module instances + def create_modules( + is_mmdit: bool, + text_encoder_idx: Optional[int], + root_module: torch.nn.Module, + target_replace_modules: List[str], + filter: Optional[str] = None, + default_dim: Optional[int] = None, + include_conv2d_if_filter: bool = False, + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_SD3 + if is_mmdit + else [self.LORA_PREFIX_TEXT_ENCODER_CLIP_L, self.LORA_PREFIX_TEXT_ENCODER_CLIP_G, self.LORA_PREFIX_TEXT_ENCODER_T5][ + text_encoder_idx + ] + ) + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: # dirty hack for all modules + module = root_module # search all modules + + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + (name + "." if name else "") + child_name + lora_name = lora_name.replace(".", "_") + + force_incl_conv2d = False + if filter is not None: + if not filter in lora_name: + continue + force_incl_conv2d = include_conv2d_if_filter + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + + if is_mmdit and type_dims is not None: + # type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim] + identifier = [ + ("context_block", "attn"), + ("context_block", "mlp"), + ("context_block", "adaLN_modulation"), + ("x_block", "attn"), + ("x_block", "mlp"), + ("x_block", "adaLN_modulation"), + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d # may be 0 for skip + break + + if is_mmdit and dim and self.train_block_indices is not None and "joint_blocks" in lora_name: + # "lora_unet_joint_blocks_0_x_block_attn_proj..." + block_index = int(lora_name.split("_")[4]) # bit dirty + if self.train_block_indices is not None and not self.train_block_indices[block_index]: + dim = 0 + + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + elif force_incl_conv2d: + # x_embedder + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + # qkv split + split_dims = None + if is_mmdit and split_qkv: + if "joint_blocks" in lora_name and "qkv" in lora_name: + split_dims = [qkv_dim // 3] * 3 + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + split_dims=split_dims, + ) + loras.append(lora) + + if target_replace_modules is None: + break # all modules are searched + return loras, skipped + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + index = i + if not train_t5xxl and index >= 2: # 0: CLIP-L, 1: CLIP-G, 2: T5XXL, so we skip T5XXL if train_t5xxl is False + break + + logger.info(f"create LoRA for Text Encoder {index+1}:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.") + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + + # create LoRA for U-Net + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.SD3_TARGET_REPLACE_MODULE) + + # emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear] + if self.emb_dims: + for filter, in_dim in zip( + [ + "context_embedder", + "_t_embedder", # don't use "t_embedder" because it's used in "context_embedder" + "x_embedder", + "y_embedder", + "final_layer_adaLN_modulation", + "final_layer_linear", + ], + self.emb_dims, + ): + # x_embedder is conv2d, so we need to include it + loras, _ = create_modules( + True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder" + ) + # if len(loras) > 0: + # logger.info(f"create LoRA for {filter}: {len(loras)} modules.") + self.unet_loras.extend(loras) + + logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") + + skipped = skipped_te + skipped_un + if verbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to split qkv + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # split qkv + for key in list(state_dict.keys()): + if not ("joint_blocks" in key and "qkv" in key): + continue + + weight = state_dict[key] + lora_name = key.split(".")[0] + if "lora_down" in key and "weight" in key: + # dense weight (rank*3, in_dim) + split_weight = torch.chunk(weight, 3, dim=0) + for i, split_w in enumerate(split_weight): + state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w + + del state_dict[key] + # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}") + elif "lora_up" in key and "weight" in key: + # sparse weight (out_dim=sum(split_dims), rank*3) + rank = weight.size(1) // 3 + i = 0 + split_dim = weight.shape[0] // 3 + for j in range(3): + state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dim, j * rank : (j + 1) * rank] + i += split_dim + del state_dict[key] + + # alpha is unchanged + + return super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination, prefix, keep_vars) + + # merge qkv + state_dict = super().state_dict(destination, prefix, keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if not ("joint_blocks" in key and "qkv" in key): + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(3)] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(3)] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + split_dim, rank = up_weights[0].size() + qkv_dim = split_dim * 3 + up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(3): + up_weight[i : i + split_dim, j * rank : (j + 1) * rank] = up_weights[j] + i += split_dim + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + + def apply_to(self, text_encoders, mmdit, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, mmdit, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if ( + key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_L) + or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_G) + or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5) + ): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_SD3): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of three elements + # if float, use the same value for all three + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): + text_encoder_lr = [default_lr, default_lr, default_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr), float(text_encoder_lr)] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0], text_encoder_lr[0]] + elif len(text_encoder_lr) == 2: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[1], text_encoder_lr[1]] + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, loraplus_ratio): + param_groups = {"lora": {}, "plus": {}} + for lora in loras: + for name, param in lora.named_parameters(): + if loraplus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * loraplus_ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, descriptions + + if self.text_encoder_loras: + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te1_loras = [ + lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_L) + ] + te2_loras = [ + lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_G) + ] + te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)] + if len(te1_loras) > 0: + logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te2_loras) > 0: + logger.info(f"Text Encoder 2 (CLIP-G): {len(te2_loras)} modules, LR {text_encoder_lr[1]}") + params, descriptions = assemble_params(te2_loras, text_encoder_lr[1], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te3_loras) > 0: + logger.info(f"Text Encoder 3 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[2]}") + params, descriptions = assemble_params(te3_loras, text_encoder_lr[2], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 3 " + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from ..library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/networks/merge_lora.py b/networks/merge_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..f86aa4863c6b1e8d26bebc26bb3456feca1e766a --- /dev/null +++ b/networks/merge_lora.py @@ -0,0 +1,360 @@ +import math +import argparse +import os +import time +import torch +from safetensors.torch import load_file, save_file +from ..library import sai_model_spec, train_util +from ..library import model_util as model_util +import lora +from ..library.utils import setup_logging +setup_logging() +import logging +logger = logging.getLogger(__name__) + +def load_state_dict(file_name, dtype): + if os.path.splitext(file_name)[1] == ".safetensors": + sd = load_file(file_name) + metadata = train_util.load_metadata_from_safetensors(file_name) + else: + sd = torch.load(file_name, map_location="cpu") + metadata = {} + + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + + return sd, metadata + + +def save_to_file(file_name, model, state_dict, dtype, metadata): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if os.path.splitext(file_name)[1] == ".safetensors": + save_file(model, file_name, metadata=metadata) + else: + torch.save(model, file_name) + + +def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): + text_encoder.to(merge_dtype) + unet.to(merge_dtype) + + # create module map + name_to_module = {} + for i, root_module in enumerate([text_encoder, unet]): + if i == 0: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER + target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + else: + prefix = lora.LoRANetwork.LORA_PREFIX_UNET + target_replace_modules = ( + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + ) + + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + name_to_module[lora_name] = child_module + + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, _ = load_state_dict(model, merge_dtype) + + logger.info(f"merging...") + for key in lora_sd.keys(): + if "lora_down" in key: + up_key = key.replace("lora_down", "lora_up") + alpha_key = key[: key.index("lora_down")] + "alpha" + + # find original module for this lora + module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" + if module_name not in name_to_module: + logger.info(f"no module found for LoRA weight: {key}") + continue + module = name_to_module[module_name] + # logger.info(f"apply {key} to {module}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + # W <- W + U * D + weight = module.weight + if len(weight.size()) == 2: + # linear + if len(up_weight.size()) == 4: # use linear projection mismatch + up_weight = up_weight.squeeze(3).squeeze(2) + down_weight = down_weight.squeeze(3).squeeze(2) + weight = weight + ratio * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + ratio + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + ratio * conved * scale + + module.weight = torch.nn.Parameter(weight) + + +def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): + base_alphas = {} # alpha for merged model + base_dims = {} + + merged_sd = {} + v2 = None + base_model = None + for model, ratio in zip(models, ratios): + logger.info(f"loading: {model}") + lora_sd, lora_metadata = load_state_dict(model, merge_dtype) + + if lora_metadata is not None: + if v2 is None: + v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # return string + if base_model is None: + base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) + + # get alpha and dim + alphas = {} # alpha for current model + dims = {} # dims for current model + for key in lora_sd.keys(): + if "alpha" in key: + lora_module_name = key[: key.rfind(".alpha")] + alpha = float(lora_sd[key].detach().numpy()) + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + elif "lora_down" in key: + lora_module_name = key[: key.rfind(".lora_down")] + dim = lora_sd[key].size()[0] + dims[lora_module_name] = dim + if lora_module_name not in base_dims: + base_dims[lora_module_name] = dim + + for lora_module_name in dims.keys(): + if lora_module_name not in alphas: + alpha = dims[lora_module_name] + alphas[lora_module_name] = alpha + if lora_module_name not in base_alphas: + base_alphas[lora_module_name] = alpha + + logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") + + # merge + logger.info(f"merging...") + for key in lora_sd.keys(): + if "alpha" in key: + continue + if "lora_up" in key and concat: + concat_dim = 1 + elif "lora_down" in key and concat: + concat_dim = 0 + else: + concat_dim = None + + lora_module_name = key[: key.rfind(".lora_")] + + base_alpha = base_alphas[lora_module_name] + alpha = alphas[lora_module_name] + + scale = math.sqrt(alpha / base_alpha) * ratio + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + + if key in merged_sd: + assert ( + merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None + ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" + if concat_dim is not None: + merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) + else: + merged_sd[key] = merged_sd[key] + lora_sd[key] * scale + else: + merged_sd[key] = lora_sd[key] * scale + + # set alpha to sd + for lora_module_name, alpha in base_alphas.items(): + key = lora_module_name + ".alpha" + merged_sd[key] = torch.tensor(alpha) + if shuffle: + key_down = lora_module_name + ".lora_down.weight" + key_up = lora_module_name + ".lora_up.weight" + dim = merged_sd[key_down].shape[0] + perm = torch.randperm(dim) + merged_sd[key_down] = merged_sd[key_down][perm] + merged_sd[key_up] = merged_sd[key_up][:,perm] + + logger.info("merged model") + logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") + + # check all dims are same + dims_list = list(set(base_dims.values())) + alphas_list = list(set(base_alphas.values())) + all_same_dims = True + all_same_alphas = True + for dims in dims_list: + if dims != dims_list[0]: + all_same_dims = False + break + for alphas in alphas_list: + if alphas != alphas_list[0]: + all_same_alphas = False + break + + # build minimum metadata + dims = f"{dims_list[0]}" if all_same_dims else "Dynamic" + alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic" + metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None) + + return merged_sd, metadata, v2 == "True" + + +def merge(args): + assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + if args.sd_model is not None: + logger.info(f"loading SD model: {args.sd_model}") + + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) + + merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) + + if args.no_metadata: + sai_metadata = None + else: + merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + None, + args.v2, + args.v2, + False, + False, + False, + time.time(), + title=title, + merged_from=merged_from, + is_stable_diffusion_ckpt=True, + ) + if args.v2: + # TODO read sai modelspec + logger.warning( + "Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" + ) + + logger.info(f"saving SD model to: {args.save_to}") + model_util.save_stable_diffusion_checkpoint( + args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae + ) + else: + state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) + + logger.info(f"calculating hashes and creating metadata...") + + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + if not args.no_metadata: + merged_from = sai_model_spec.build_merged_from(args.models) + title = os.path.splitext(os.path.basename(args.save_to))[0] + sai_metadata = sai_model_spec.build_metadata( + state_dict, v2, v2, False, True, False, time.time(), title=title, merged_from=merged_from + ) + if v2: + # TODO read sai modelspec + logger.warning( + "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" + ) + metadata.update(sai_metadata) + + logger.info(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ", + ) + parser.add_argument( + "--precision", + type=str, + default="float", + choices=["float", "fp16", "bf16"], + help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)", + ) + parser.add_argument( + "--sd_model", + type=str, + default=None, + help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする", + ) + parser.add_argument( + "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + ) + parser.add_argument( + "--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors" + ) + parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率") + parser.add_argument( + "--no_metadata", + action="store_true", + help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", + ) + parser.add_argument( + "--concat", + action="store_true", + help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / " + + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)", + ) + parser.add_argument( + "--shuffle", + action="store_true", + help="shuffle lora weight./ " + + "LoRAの重みをシャッフルする", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + merge(args) diff --git a/networks/resize_lora.py b/networks/resize_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..0409d5dfb9aabcc0c890c7b35364509c9593b3e5 --- /dev/null +++ b/networks/resize_lora.py @@ -0,0 +1,411 @@ +# Convert LoRA to different rank approximation (should only be used to go to lower rank) +# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo + +import os +import argparse +import torch +from safetensors.torch import load_file, save_file, safe_open +from tqdm import tqdm +import numpy as np + +from ..library import train_util +from ..library import model_util +from ..library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +MIN_SV = 1e-6 + +# Model save and load functions + + +def load_state_dict(file_name, dtype): + if model_util.is_safetensors(file_name): + sd = load_file(file_name) + with safe_open(file_name, framework="pt") as f: + metadata = f.metadata() + else: + sd = torch.load(file_name, map_location="cpu") + metadata = None + + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + + return sd, metadata + + +def save_to_file(file_name, state_dict, dtype, metadata): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if model_util.is_safetensors(file_name): + save_file(state_dict, file_name, metadata) + else: + torch.save(state_dict, file_name) + + +# Indexing functions + + +def index_sv_cumulative(S, target): + original_sum = float(torch.sum(S)) + cumulative_sums = torch.cumsum(S, dim=0) / original_sum + index = int(torch.searchsorted(cumulative_sums, target)) + 1 + index = max(1, min(index, len(S) - 1)) + + return index + + +def index_sv_fro(S, target): + S_squared = S.pow(2) + S_fro_sq = float(torch.sum(S_squared)) + sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq + index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 + index = max(1, min(index, len(S) - 1)) + + return index + + +def index_sv_ratio(S, target): + max_sv = S[0] + min_sv = max_sv / target + index = int(torch.sum(S > min_sv).item()) + index = max(1, min(index, len(S) - 1)) + + return index + + +# Modified from Kohaku-blueleaf's extract/merge functions +def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): + out_size, in_size, kernel_size, _ = weight.size() + U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device)) + + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) + lora_rank = param_dict["new_rank"] + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu() + param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu() + del U, S, Vh, weight + return param_dict + + +def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): + out_size, in_size = weight.size() + + U, S, Vh = torch.linalg.svd(weight.to(device)) + + param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) + lora_rank = param_dict["new_rank"] + + U = U[:, :lora_rank] + S = S[:lora_rank] + U = U @ torch.diag(S) + Vh = Vh[:lora_rank, :] + + param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu() + param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu() + del U, S, Vh, weight + return param_dict + + +def merge_conv(lora_down, lora_up, device): + in_rank, in_size, kernel_size, k_ = lora_down.shape + out_size, out_rank, _, _ = lora_up.shape + assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch" + + lora_down = lora_down.to(device) + lora_up = lora_up.to(device) + + merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1) + weight = merged.reshape(out_size, in_size, kernel_size, kernel_size) + del lora_up, lora_down + return weight + + +def merge_linear(lora_down, lora_up, device): + in_rank, in_size = lora_down.shape + out_size, out_rank = lora_up.shape + assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch" + + lora_down = lora_down.to(device) + lora_up = lora_up.to(device) + + weight = lora_up @ lora_down + del lora_up, lora_down + return weight + + +# Calculate new rank + + +def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): + param_dict = {} + + if dynamic_method == "sv_ratio": + # Calculate new dim and alpha based off ratio + new_rank = index_sv_ratio(S, dynamic_param) + 1 + new_alpha = float(scale * new_rank) + + elif dynamic_method == "sv_cumulative": + # Calculate new dim and alpha based off cumulative sum + new_rank = index_sv_cumulative(S, dynamic_param) + 1 + new_alpha = float(scale * new_rank) + + elif dynamic_method == "sv_fro": + # Calculate new dim and alpha based off sqrt sum of squares + new_rank = index_sv_fro(S, dynamic_param) + 1 + new_alpha = float(scale * new_rank) + else: + new_rank = rank + new_alpha = float(scale * new_rank) + + if S[0] <= MIN_SV: # Zero matrix, set dim to 1 + new_rank = 1 + new_alpha = float(scale * new_rank) + elif new_rank > rank: # cap max rank at rank + new_rank = rank + new_alpha = float(scale * new_rank) + + # Calculate resize info + s_sum = torch.sum(torch.abs(S)) + s_rank = torch.sum(torch.abs(S[:new_rank])) + + S_squared = S.pow(2) + s_fro = torch.sqrt(torch.sum(S_squared)) + s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank])) + fro_percent = float(s_red_fro / s_fro) + + param_dict["new_rank"] = new_rank + param_dict["new_alpha"] = new_alpha + param_dict["sum_retained"] = (s_rank) / s_sum + param_dict["fro_retained"] = fro_percent + param_dict["max_ratio"] = S[0] / S[new_rank - 1] + + return param_dict + + +def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): + network_alpha = None + network_dim = None + verbose_str = "\n" + fro_list = [] + + # Extract loaded lora dim and alpha + for key, value in lora_sd.items(): + if network_alpha is None and "alpha" in key: + network_alpha = value + if network_dim is None and "lora_down" in key and len(value.size()) == 2: + network_dim = value.size()[0] + if network_alpha is not None and network_dim is not None: + break + if network_alpha is None: + network_alpha = network_dim + + scale = network_alpha / network_dim + + if dynamic_method: + logger.info( + f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}" + ) + + lora_down_weight = None + lora_up_weight = None + + o_lora_sd = lora_sd.copy() + block_down_name = None + block_up_name = None + + with torch.no_grad(): + for key, value in tqdm(lora_sd.items()): + weight_name = None + if "lora_down" in key: + block_down_name = key.rsplit(".lora_down", 1)[0] + weight_name = key.rsplit(".", 1)[-1] + lora_down_weight = value + else: + continue + + # find corresponding lora_up and alpha + block_up_name = block_down_name + lora_up_weight = lora_sd.get(block_up_name + ".lora_up." + weight_name, None) + lora_alpha = lora_sd.get(block_down_name + ".alpha", None) + + weights_loaded = lora_down_weight is not None and lora_up_weight is not None + + if weights_loaded: + + conv2d = len(lora_down_weight.size()) == 4 + if lora_alpha is None: + scale = 1.0 + else: + scale = lora_alpha / lora_down_weight.size()[0] + + if conv2d: + full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) + param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale) + else: + full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) + param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) + + if verbose: + max_ratio = param_dict["max_ratio"] + sum_retained = param_dict["sum_retained"] + fro_retained = param_dict["fro_retained"] + if not np.isnan(fro_retained): + fro_list.append(float(fro_retained)) + + verbose_str += f"{block_down_name:75} | " + verbose_str += ( + f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" + ) + + if verbose and dynamic_method: + verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n" + else: + verbose_str += "\n" + + new_alpha = param_dict["new_alpha"] + o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous() + o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype) + + block_down_name = None + block_up_name = None + lora_down_weight = None + lora_up_weight = None + weights_loaded = False + del param_dict + + if verbose: + print(verbose_str) + print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") + logger.info("resizing complete") + return o_lora_sd, network_dim, new_alpha + + +def resize(args): + if args.save_to is None or not ( + args.save_to.endswith(".ckpt") + or args.save_to.endswith(".pt") + or args.save_to.endswith(".pth") + or args.save_to.endswith(".safetensors") + ): + raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.") + + args.new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank + + def str_to_dtype(p): + if p == "float": + return torch.float + if p == "fp16": + return torch.float16 + if p == "bf16": + return torch.bfloat16 + return None + + if args.dynamic_method and not args.dynamic_param: + raise Exception("If using dynamic_method, then dynamic_param is required") + + merge_dtype = str_to_dtype("float") # matmul method above only seems to work in float32 + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + logger.info("loading Model...") + lora_sd, metadata = load_state_dict(args.model, merge_dtype) + + logger.info("Resizing Lora...") + state_dict, old_dim, new_alpha = resize_lora_model( + lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose + ) + + # update metadata + if metadata is None: + metadata = {} + + comment = metadata.get("ss_training_comment", "") + + if not args.dynamic_method: + conv_desc = "" if args.new_rank == args.new_conv_rank else f" (conv: {args.new_conv_rank})" + metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}{conv_desc}; {comment}" + metadata["ss_network_dim"] = str(args.new_rank) + metadata["ss_network_alpha"] = str(new_alpha) + else: + metadata["ss_training_comment"] = ( + f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}" + ) + metadata["ss_network_dim"] = "Dynamic" + metadata["ss_network_alpha"] = "Dynamic" + + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + logger.info(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, save_dtype, metadata) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + parser.add_argument( + "--save_precision", + type=str, + default=None, + choices=[None, "float", "fp16", "bf16"], + help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat", + ) + parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)") + parser.add_argument( + "--new_conv_rank", + type=int, + default=None, + help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ", + ) + parser.add_argument( + "--save_to", + type=str, + default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors", + ) + parser.add_argument( + "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う" + ) + parser.add_argument( + "--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する" + ) + parser.add_argument( + "--dynamic_method", + type=str, + default=None, + choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"], + help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank", + ) + parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction") + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + resize(args) diff --git a/nodes.py b/nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..49a59721a984dc785b615d56693a283a71d591d6 --- /dev/null +++ b/nodes.py @@ -0,0 +1,1798 @@ +import os +import torch +from torchvision import transforms + +import folder_paths +import comfy.model_management as mm +import comfy.utils +import toml +import json +import time +import shutil +import shlex + +from pathlib import Path +script_directory = os.path.dirname(os.path.abspath(__file__)) + +from .flux_train_network_comfy import FluxNetworkTrainer +from .library import flux_train_utils as flux_train_utils +from .flux_train_comfy import FluxTrainer +from .flux_train_comfy import setup_parser as train_setup_parser +from .library.device_utils import init_ipex +init_ipex() + +from .library import train_util +from .train_network import setup_parser as train_network_setup_parser +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import io +from PIL import Image + +import logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +class FluxTrainModelSelect: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "transformer": (folder_paths.get_filename_list("unet"), ), + "vae": (folder_paths.get_filename_list("vae"), ), + "clip_l": (folder_paths.get_filename_list("clip"), ), + "t5": (folder_paths.get_filename_list("clip"), ), + }, + "optional": { + "lora_path": ("STRING",{"multiline": True, "forceInput": True, "default": "", "tooltip": "pre-trained LoRA path to load (network_weights)"}), + } + } + + RETURN_TYPES = ("TRAIN_FLUX_MODELS",) + RETURN_NAMES = ("flux_models",) + FUNCTION = "loadmodel" + CATEGORY = "FluxTrainer" + + def loadmodel(self, transformer, vae, clip_l, t5, lora_path=""): + + transformer_path = folder_paths.get_full_path("unet", transformer) + vae_path = folder_paths.get_full_path("vae", vae) + clip_path = folder_paths.get_full_path("clip", clip_l) + t5_path = folder_paths.get_full_path("clip", t5) + + flux_models = { + "transformer": transformer_path, + "vae": vae_path, + "clip_l": clip_path, + "t5": t5_path, + "lora_path": lora_path + } + + return (flux_models,) + +class TrainDatasetGeneralConfig: + queue_counter = 0 + @classmethod + def IS_CHANGED(s, reset_on_queue=False, **kwargs): + if reset_on_queue: + s.queue_counter += 1 + print(f"queue_counter: {s.queue_counter}") + return s.queue_counter + @classmethod + def INPUT_TYPES(s): + return {"required": { + "color_aug": ("BOOLEAN",{"default": False, "tooltip": "enable weak color augmentation"}), + "flip_aug": ("BOOLEAN",{"default": False, "tooltip": "enable horizontal flip augmentation"}), + "shuffle_caption": ("BOOLEAN",{"default": False, "tooltip": "shuffle caption"}), + "caption_dropout_rate": ("FLOAT",{"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01,"tooltip": "tag dropout rate"}), + "alpha_mask": ("BOOLEAN",{"default": False, "tooltip": "use alpha channel as mask for training"}), + }, + "optional": { + "reset_on_queue": ("BOOLEAN",{"default": False, "tooltip": "Force refresh of everything for cleaner queueing"}), + "caption_extension": ("STRING",{"default": ".txt", "tooltip": "extension for caption files"}), + } + } + + RETURN_TYPES = ("JSON",) + RETURN_NAMES = ("dataset_general",) + FUNCTION = "create_config" + CATEGORY = "FluxTrainer" + + def create_config(self, shuffle_caption, caption_dropout_rate, color_aug, flip_aug, alpha_mask, reset_on_queue=False, caption_extension=".txt"): + + dataset = { + "general": { + "shuffle_caption": shuffle_caption, + "caption_extension": caption_extension, + "keep_tokens_separator": "|||", + "caption_dropout_rate": caption_dropout_rate, + "color_aug": color_aug, + "flip_aug": flip_aug, + }, + "datasets": [] + } + dataset_json = json.dumps(dataset, indent=2) + #print(dataset_json) + dataset_config = { + "datasets": dataset_json, + "alpha_mask": alpha_mask + } + return (dataset_config,) + +class TrainDatasetRegularization: + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "dataset_path": ("STRING",{"multiline": True, "default": "", "tooltip": "path to dataset, root is the 'ComfyUI' folder, with windows portable 'ComfyUI_windows_portable'"}), + "class_tokens": ("STRING",{"multiline": True, "default": "", "tooltip": "aka trigger word, if specified, will be added to the start of each caption, if no captions exist, will be used on it's own"}), + "num_repeats": ("INT", {"default": 1, "min": 1, "tooltip": "number of times to repeat dataset for an epoch"}), + }, + } + + RETURN_TYPES = ("JSON",) + RETURN_NAMES = ("subset",) + FUNCTION = "create_config" + CATEGORY = "FluxTrainer" + + def create_config(self, dataset_path, class_tokens, num_repeats): + + reg_subset = { + "image_dir": dataset_path, + "class_tokens": class_tokens, + "num_repeats": num_repeats, + "is_reg": True + } + + return reg_subset, + +class TrainDatasetAdd: + def __init__(self): + self.previous_dataset_signature = None + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "dataset_config": ("JSON",), + "width": ("INT",{"min": 64, "default": 1024, "tooltip": "base resolution width"}), + "height": ("INT",{"min": 64, "default": 1024, "tooltip": "base resolution height"}), + "batch_size": ("INT",{"min": 1, "default": 2, "tooltip": "Higher batch size uses more memory and generalizes the training more"}), + "dataset_path": ("STRING",{"multiline": True, "default": "", "tooltip": "path to dataset, root is the 'ComfyUI' folder, with windows portable 'ComfyUI_windows_portable'"}), + "class_tokens": ("STRING",{"multiline": True, "default": "", "tooltip": "aka trigger word, if specified, will be added to the start of each caption, if no captions exist, will be used on it's own"}), + "enable_bucket": ("BOOLEAN",{"default": True, "tooltip": "enable buckets for multi aspect ratio training"}), + "bucket_no_upscale": ("BOOLEAN",{"default": False, "tooltip": "don't allow upscaling when bucketing"}), + "num_repeats": ("INT", {"default": 1, "min": 1, "tooltip": "number of times to repeat dataset for an epoch"}), + "min_bucket_reso": ("INT", {"default": 256, "min": 64, "max": 4096, "step": 8, "tooltip": "min bucket resolution"}), + "max_bucket_reso": ("INT", {"default": 1024, "min": 64, "max": 4096, "step": 8, "tooltip": "max bucket resolution"}), + }, + "optional": { + "regularization": ("JSON", {"tooltip": "reg data dir"}), + } + } + + RETURN_TYPES = ("JSON",) + RETURN_NAMES = ("dataset",) + FUNCTION = "create_config" + CATEGORY = "FluxTrainer" + + def create_config(self, dataset_config, dataset_path, class_tokens, width, height, batch_size, num_repeats, enable_bucket, + bucket_no_upscale, min_bucket_reso, max_bucket_reso, regularization=None): + + new_dataset = { + "resolution": (width, height), + "batch_size": batch_size, + "enable_bucket": enable_bucket, + "bucket_no_upscale": bucket_no_upscale, + "min_bucket_reso": min_bucket_reso, + "max_bucket_reso": max_bucket_reso, + "subsets": [ + { + "image_dir": dataset_path, + "class_tokens": class_tokens, + "num_repeats": num_repeats + } + ] + } + if regularization is not None: + new_dataset["subsets"].append(regularization) + + # Generate a signature for the new dataset + new_dataset_signature = self.generate_signature(new_dataset) + + # Load the existing datasets + existing_datasets = json.loads(dataset_config["datasets"]) + + # Remove the previously added dataset if it exists + if self.previous_dataset_signature: + existing_datasets["datasets"] = [ + ds for ds in existing_datasets["datasets"] + if self.generate_signature(ds) != self.previous_dataset_signature + ] + + # Add the new dataset + existing_datasets["datasets"].append(new_dataset) + + # Store the new dataset signature for future runs + self.previous_dataset_signature = new_dataset_signature + + # Convert back to JSON and update dataset_config + updated_dataset_json = json.dumps(existing_datasets, indent=2) + dataset_config["datasets"] = updated_dataset_json + + return dataset_config, + + def generate_signature(self, dataset): + # Create a unique signature for the dataset based on its attributes + return json.dumps(dataset, sort_keys=True) + +class OptimizerConfig: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "optimizer_type": (["adamw8bit", "adamw","prodigy", "CAME", "Lion8bit", "Lion", "adamwschedulefree", "sgdschedulefree", "AdEMAMix8bit", "PagedAdEMAMix8bit", "ProdigyPlusScheduleFree"], {"default": "adamw8bit", "tooltip": "optimizer type"}), + "max_grad_norm": ("FLOAT",{"default": 1.0, "min": 0.0, "tooltip": "gradient clipping"}), + "lr_scheduler": (["constant", "cosine", "cosine_with_restarts", "polynomial", "constant_with_warmup"], {"default": "constant", "tooltip": "learning rate scheduler"}), + "lr_warmup_steps": ("INT",{"default": 0, "min": 0, "tooltip": "learning rate warmup steps"}), + "lr_scheduler_num_cycles": ("INT",{"default": 1, "min": 1, "tooltip": "learning rate scheduler num cycles"}), + "lr_scheduler_power": ("FLOAT",{"default": 1.0, "min": 0.0, "tooltip": "learning rate scheduler power"}), + "min_snr_gamma": ("FLOAT",{"default": 5.0, "min": 0.0, "step": 0.01, "tooltip": "gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by the paper"}), + "extra_optimizer_args": ("STRING",{"multiline": True, "default": "", "tooltip": "additional optimizer args"}), + }, + } + + RETURN_TYPES = ("ARGS",) + RETURN_NAMES = ("optimizer_settings",) + FUNCTION = "create_config" + CATEGORY = "FluxTrainer" + + def create_config(self, min_snr_gamma, extra_optimizer_args, **kwargs): + kwargs["min_snr_gamma"] = min_snr_gamma if min_snr_gamma != 0.0 else None + kwargs["optimizer_args"] = [arg.strip() for arg in extra_optimizer_args.strip().split('|') if arg.strip()] + return (kwargs,) + +class OptimizerConfigAdafactor: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "max_grad_norm": ("FLOAT",{"default": 0.0, "min": 0.0, "tooltip": "gradient clipping"}), + "lr_scheduler": (["constant", "cosine", "cosine_with_restarts", "polynomial", "constant_with_warmup", "adafactor"], {"default": "constant_with_warmup", "tooltip": "learning rate scheduler"}), + "lr_warmup_steps": ("INT",{"default": 0, "min": 0, "tooltip": "learning rate warmup steps"}), + "lr_scheduler_num_cycles": ("INT",{"default": 1, "min": 1, "tooltip": "learning rate scheduler num cycles"}), + "lr_scheduler_power": ("FLOAT",{"default": 1.0, "min": 0.0, "tooltip": "learning rate scheduler power"}), + "relative_step": ("BOOLEAN",{"default": False, "tooltip": "relative step"}), + "scale_parameter": ("BOOLEAN",{"default": False, "tooltip": "scale parameter"}), + "warmup_init": ("BOOLEAN",{"default": False, "tooltip": "warmup init"}), + "clip_threshold": ("FLOAT",{"default": 1.0, "min": 0.0, "tooltip": "clip threshold"}), + "min_snr_gamma": ("FLOAT",{"default": 5.0, "min": 0.0, "step": 0.01, "tooltip": "gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by the paper"}), + "extra_optimizer_args": ("STRING",{"multiline": True, "default": "", "tooltip": "additional optimizer args"}), + }, + } + + RETURN_TYPES = ("ARGS",) + RETURN_NAMES = ("optimizer_settings",) + FUNCTION = "create_config" + CATEGORY = "FluxTrainer" + + def create_config(self, relative_step, scale_parameter, warmup_init, clip_threshold, min_snr_gamma, extra_optimizer_args, **kwargs): + kwargs["optimizer_type"] = "adafactor" + extra_args = [arg.strip() for arg in extra_optimizer_args.strip().split('|') if arg.strip()] + node_args = [ + f"relative_step={relative_step}", + f"scale_parameter={scale_parameter}", + f"warmup_init={warmup_init}", + f"clip_threshold={clip_threshold}" + ] + kwargs["optimizer_args"] = node_args + extra_args + kwargs["min_snr_gamma"] = min_snr_gamma if min_snr_gamma != 0.0 else None + + return (kwargs,) + +class FluxTrainerLossConfig: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "loss_type": (["l2", "huber","smooth_l1"], {"default": "huber", "tooltip": "The type of loss function to use"}), + "huber_schedule": (["snr", "exponential", "constant"], {"default": "exponential", "tooltip": "The scheduling method for Huber loss (constant, exponential, or SNR-based). Only used when loss_type is 'huber' or 'smooth_l1'. default is snr"}), + "huber_c": ("FLOAT",{"default": 0.25, "min": 0.0, "step": 0.01, "tooltip": "The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1"}), + "huber_scale": ("FLOAT",{"default": 1.75, "min": 0.0, "step": 0.01, "tooltip": "The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0"}), + }, + } + + RETURN_TYPES = ("ARGS",) + RETURN_NAMES = ("loss_args",) + FUNCTION = "create_config" + CATEGORY = "FluxTrainer" + + def create_config(self, **kwargs): + return (kwargs,) + +class OptimizerConfigProdigy: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "max_grad_norm": ("FLOAT",{"default": 0.0, "min": 0.0, "tooltip": "gradient clipping"}), + "lr_scheduler": (["constant", "cosine", "cosine_with_restarts", "polynomial", "constant_with_warmup", "adafactor"], {"default": "constant", "tooltip": "learning rate scheduler"}), + "lr_warmup_steps": ("INT",{"default": 0, "min": 0, "tooltip": "learning rate warmup steps"}), + "lr_scheduler_num_cycles": ("INT",{"default": 1, "min": 1, "tooltip": "learning rate scheduler num cycles"}), + "lr_scheduler_power": ("FLOAT",{"default": 1.0, "min": 0.0, "tooltip": "learning rate scheduler power"}), + "weight_decay": ("FLOAT",{"default": 0.0, "step": 0.0001, "tooltip": "weight decay (L2 penalty)"}), + "decouple": ("BOOLEAN",{"default": True, "tooltip": "use AdamW style weight decay"}), + "use_bias_correction": ("BOOLEAN",{"default": False, "tooltip": "turn on Adam's bias correction"}), + "min_snr_gamma": ("FLOAT",{"default": 5.0, "min": 0.0, "step": 0.01, "tooltip": "gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by the paper"}), + "extra_optimizer_args": ("STRING",{"multiline": True, "default": "", "tooltip": "additional optimizer args"}), + }, + } + + RETURN_TYPES = ("ARGS",) + RETURN_NAMES = ("optimizer_settings",) + FUNCTION = "create_config" + CATEGORY = "FluxTrainer" + + def create_config(self, weight_decay, decouple, min_snr_gamma, use_bias_correction, extra_optimizer_args, **kwargs): + kwargs["optimizer_type"] = "prodigy" + extra_args = [arg.strip() for arg in extra_optimizer_args.strip().split('|') if arg.strip()] + node_args = [ + f"weight_decay={weight_decay}", + f"decouple={decouple}", + f"use_bias_correction={use_bias_correction}" + ] + kwargs["optimizer_args"] = node_args + extra_args + kwargs["min_snr_gamma"] = min_snr_gamma if min_snr_gamma != 0.0 else None + + return (kwargs,) + +class TrainNetworkConfig: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_type": (["lora", "LyCORIS/LoKr", "LyCORIS/Locon", "LyCORIS/LoHa"], {"default": "lora", "tooltip": "network type"}), + "lycoris_preset": (["full", "full-lin", "attn-mlp", "attn-only"], {"default": "attn-mlp"}), + "factor": ("INT",{"default": -1, "min": -1, "max": 16, "step": 1, "tooltip": "LoKr factor"}), + "extra_network_args": ("STRING",{"multiline": True, "default": "", "tooltip": "additional network args"}), + }, + } + + RETURN_TYPES = ("NETWORK_CONFIG",) + RETURN_NAMES = ("network_config",) + FUNCTION = "create_config" + CATEGORY = "FluxTrainer" + + def create_config(self, network_type, extra_network_args, lycoris_preset, factor): + + extra_args = [arg.strip() for arg in extra_network_args.strip().split('|') if arg.strip()] + + if network_type == "lora": + network_module = ".networks.lora" + elif network_type == "LyCORIS/LoKr": + network_module = ".lycoris.kohya" + algo = "lokr" + elif network_type == "LyCORIS/Locon": + network_module = ".lycoris.kohya" + algo = "locon" + elif network_type == "LyCORIS/LoHa": + network_module = ".lycoris.kohya" + algo = "loha" + + network_args = [ + f"algo={algo}", + f"factor={factor}", + f"preset={lycoris_preset}" + ] + network_config = { + "network_module": network_module, + "network_args": network_args + extra_args + } + + return (network_config,) + +class OptimizerConfigProdigyPlusScheduleFree: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "lr": ("FLOAT",{"default": 1.0, "min": 0.0, "step": 1e-7, "tooltip": "Learning rate adjustment parameter. Increases or decreases the Prodigy learning rate."}), + "max_grad_norm": ("FLOAT",{"default": 0.0, "min": 0.0, "tooltip": "gradient clipping"}), + "prodigy_steps": ("INT",{"default": 0, "min": 0, "tooltip": "Freeze Prodigy stepsize adjustments after a certain optimiser step."}), + "d0": ("FLOAT",{"default": 1e-6, "min": 0.0,"step": 1e-7, "tooltip": "initial learning rate"}), + "d_coeff": ("FLOAT",{"default": 1.0, "min": 0.0, "step": 1e-7, "tooltip": "Coefficient in the expression for the estimate of d (default 1.0). Values such as 0.5 and 2.0 typically work as well. Changing this parameter is the preferred way to tune the method."}), + "split_groups": ("BOOLEAN",{"default": True, "tooltip": "Track individual adaptation values for each parameter group."}), + #"beta3": ("FLOAT",{"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.0001, "tooltip": " Coefficient for computing the Prodigy stepsize using running averages. If set to None, uses the value of square root of beta2 (default: None)."}), + #"beta4": ("FLOAT",{"default": 0, "min": 0.0, "max": 1.0, "step": 0.0001, "tooltip": "Coefficient for updating the learning rate from Prodigy's adaptive stepsize. Smooths out spikes in learning rate adjustments. If set to None, beta1 is used instead. (default 0, which disables smoothing and uses original Prodigy behaviour)."}), + "use_bias_correction": ("BOOLEAN",{"default": False, "tooltip": "Use the RAdam variant of schedule-free"}), + "min_snr_gamma": ("FLOAT",{"default": 5.0, "min": 0.0, "step": 0.01, "tooltip": "gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by the paper"}), + "use_stableadamw": ("BOOLEAN",{"default": True, "tooltip": "Scales parameter updates by the root-mean-square of the normalised gradient, in essence identical to Adafactor's gradient scaling. Set to False if the adaptive learning rate never improves."}), + "use_cautious" : ("BOOLEAN",{"default": False, "tooltip": "Experimental. Perform 'cautious' updates, as proposed in https://arxiv.org/pdf/2411.16085. Modifies the update to isolate and boost values that align with the current gradient."}), + "use_adopt": ("BOOLEAN",{"default": False, "tooltip": "Experimental. Performs a modified step where the second moment is updated after the parameter update, so as not to include the current gradient in the denominator. This is a partial implementation of ADOPT (https://arxiv.org/abs/2411.02853), as we don't have a first moment to use for the update."}), + "use_grams": ("BOOLEAN",{"default": False, "tooltip": "Perform 'grams' updates, as proposed in https://arxiv.org/abs/2412.17107. Modifies the update using sign operations that align with the current gradient. Note that we do not have access to a first moment, so this deviates from the paper (we apply the sign directly to the update). May have a limited effect."}), + "stochastic_rounding": ("BOOLEAN",{"default": True, "tooltip": "Use stochastic rounding for bfloat16 weights"}), + "use_orthograd": ("BOOLEAN",{"default": False, "tooltip": "Experimental. Updates weights using the component of the gradient that is orthogonal to the current weight direction, as described in (https://arxiv.org/pdf/2501.04697). Can help prevent overfitting and improve generalisation."}), + "use_focus ": ("BOOLEAN",{"default": False, "tooltip": "Experimental. Modifies the update step to better handle noise at large step sizes. (https://arxiv.org/abs/2501.12243). This method is incompatible with factorisation, Muon and Adam-atan2."}), + "extra_optimizer_args": ("STRING",{"multiline": True, "default": "", "tooltip": "additional optimizer args"}), + }, + } + + RETURN_TYPES = ("ARGS",) + RETURN_NAMES = ("optimizer_settings",) + FUNCTION = "create_config" + CATEGORY = "FluxTrainer" + + def create_config(self, min_snr_gamma, use_bias_correction, extra_optimizer_args, **kwargs): + kwargs["optimizer_type"] = "ProdigyPlusScheduleFree" + kwargs["lr_scheduler"] = "constant" + extra_args = [arg.strip() for arg in extra_optimizer_args.strip().split('|') if arg.strip()] + node_args = [ + f"use_bias_correction={use_bias_correction}", + ] + kwargs["optimizer_args"] = node_args + extra_args + kwargs["min_snr_gamma"] = min_snr_gamma if min_snr_gamma != 0.0 else None + + return (kwargs,) + +class InitFluxLoRATraining: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "flux_models": ("TRAIN_FLUX_MODELS",), + "dataset": ("JSON",), + "optimizer_settings": ("ARGS",), + "output_name": ("STRING", {"default": "flux_lora", "multiline": False}), + "output_dir": ("STRING", {"default": "flux_trainer_output", "multiline": False, "tooltip": "path to dataset, root is the 'ComfyUI' folder, with windows portable 'ComfyUI_windows_portable'"}), + "network_dim": ("INT", {"default": 4, "min": 1, "max": 100000, "step": 1, "tooltip": "network dim"}), + "network_alpha": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2048.0, "step": 0.01, "tooltip": "network alpha"}), + "learning_rate": ("FLOAT", {"default": 4e-4, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "learning rate"}), + "max_train_steps": ("INT", {"default": 1500, "min": 1, "max": 100000, "step": 1, "tooltip": "max number of training steps"}), + "apply_t5_attn_mask": ("BOOLEAN", {"default": True, "tooltip": "apply t5 attention mask"}), + "cache_latents": (["disk", "memory", "disabled"], {"tooltip": "caches text encoder outputs"}), + "cache_text_encoder_outputs": (["disk", "memory", "disabled"], {"tooltip": "caches text encoder outputs"}), + "blocks_to_swap": ("INT", {"default": 0, "tooltip": "Previously known as split_mode, number of blocks to swap to save memory, default to enable is 18"}), + "weighting_scheme": (["logit_normal", "sigma_sqrt", "mode", "cosmap", "none"],), + "logit_mean": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "mean to use when using the logit_normal weighting scheme"}), + "logit_std": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01,"tooltip": "std to use when using the logit_normal weighting scheme"}), + "mode_scale": ("FLOAT", {"default": 1.29, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Scale of mode weighting scheme. Only effective when using the mode as the weighting_scheme"}), + "timestep_sampling": (["sigmoid", "uniform", "sigma", "shift", "flux_shift"], {"tooltip": "Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid (recommend value of 3.1582 for discrete_flow_shift)"}), + "sigmoid_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1, "tooltip": "Scale factor for sigmoid timestep sampling (only used when timestep-sampling is sigmoid"}), + "model_prediction_type": (["raw", "additive", "sigma_scaled"], {"tooltip": "How to interpret and process the model prediction: raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."}), + "guidance_scale": ("FLOAT", {"default": 1.0, "min": 1.0, "max": 32.0, "step": 0.01, "tooltip": "guidance scale, for Flux training should be 1.0"}), + "discrete_flow_shift": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.0001, "tooltip": "for the Euler Discrete Scheduler, default is 3.0"}), + "highvram": ("BOOLEAN", {"default": False, "tooltip": "memory mode"}), + "fp8_base": ("BOOLEAN", {"default": True, "tooltip": "use fp8 for base model"}), + "gradient_dtype": (["fp32", "fp16", "bf16"], {"default": "fp32", "tooltip": "the actual dtype training uses"}), + "save_dtype": (["fp32", "fp16", "bf16", "fp8_e4m3fn", "fp8_e5m2"], {"default": "bf16", "tooltip": "the dtype to save checkpoints as"}), + "attention_mode": (["sdpa", "xformers", "disabled"], {"default": "sdpa", "tooltip": "memory efficient attention mode"}), + "sample_prompts": ("STRING", {"multiline": True, "default": "illustration of a kitten | photograph of a turtle", "tooltip": "validation sample prompts, for multiple prompts, separate by `|`"}), + }, + "optional": { + "additional_args": ("STRING", {"multiline": True, "default": "", "tooltip": "additional args to pass to the training command"}), + "resume_args": ("ARGS", {"default": "", "tooltip": "resume args to pass to the training command"}), + "train_text_encoder": (['disabled', 'clip_l', 'clip_l_fp8', 'clip_l+T5', 'clip_l+T5_fp8'], {"default": 'disabled', "tooltip": "also train the selected text encoders using specified dtype, T5 can not be trained without clip_l"}), + "clip_l_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}), + "T5_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}), + "block_args": ("ARGS", {"default": "", "tooltip": "limit the blocks used in the LoRA"}), + "gradient_checkpointing": (["enabled", "enabled_with_cpu_offloading", "disabled"], {"default": "enabled", "tooltip": "use gradient checkpointing"}), + "loss_args": ("ARGS", {"default": "", "tooltip": "loss args"}), + "network_config": ("NETWORK_CONFIG", {"tooltip": "additional network config"}), + }, + "hidden": { + "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO" + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT", "KOHYA_ARGS",) + RETURN_NAMES = ("network_trainer", "epochs_count", "args",) + FUNCTION = "init_training" + CATEGORY = "FluxTrainer" + + def init_training(self, flux_models, dataset, optimizer_settings, sample_prompts, output_name, attention_mode, + gradient_dtype, save_dtype, additional_args=None, resume_args=None, train_text_encoder='disabled', + block_args=None, gradient_checkpointing="enabled", prompt=None, extra_pnginfo=None, clip_l_lr=0, T5_lr=0, loss_args=None, network_config=None, **kwargs): + mm.soft_empty_cache() + + output_dir = os.path.abspath(kwargs.get("output_dir")) + os.makedirs(output_dir, exist_ok=True) + + total, used, free = shutil.disk_usage(output_dir) + + required_free_space = 2 * (2**30) + if free <= required_free_space: + raise ValueError(f"Insufficient disk space. Required: {required_free_space/2**30}GB. Available: {free/2**30}GB") + + dataset_config = dataset["datasets"] + dataset_toml = toml.dumps(json.loads(dataset_config)) + + parser = train_network_setup_parser() + flux_train_utils.add_flux_train_arguments(parser) + + if additional_args is not None: + print(f"additional_args: {additional_args}") + args, _ = parser.parse_known_args(args=shlex.split(additional_args)) + else: + args, _ = parser.parse_known_args() + + if kwargs.get("cache_latents") == "memory": + kwargs["cache_latents"] = True + kwargs["cache_latents_to_disk"] = False + elif kwargs.get("cache_latents") == "disk": + kwargs["cache_latents"] = True + kwargs["cache_latents_to_disk"] = True + kwargs["caption_dropout_rate"] = 0.0 + kwargs["shuffle_caption"] = False + kwargs["token_warmup_step"] = 0.0 + kwargs["caption_tag_dropout_rate"] = 0.0 + else: + kwargs["cache_latents"] = False + kwargs["cache_latents_to_disk"] = False + + if kwargs.get("cache_text_encoder_outputs") == "memory": + kwargs["cache_text_encoder_outputs"] = True + kwargs["cache_text_encoder_outputs_to_disk"] = False + elif kwargs.get("cache_text_encoder_outputs") == "disk": + kwargs["cache_text_encoder_outputs"] = True + kwargs["cache_text_encoder_outputs_to_disk"] = True + else: + kwargs["cache_text_encoder_outputs"] = False + kwargs["cache_text_encoder_outputs_to_disk"] = False + + if '|' in sample_prompts: + prompts = sample_prompts.split('|') + else: + prompts = [sample_prompts] + + config_dict = { + "sample_prompts": prompts, + "save_precision": save_dtype, + "mixed_precision": "bf16", + "num_cpu_threads_per_process": 1, + "pretrained_model_name_or_path": flux_models["transformer"], + "clip_l": flux_models["clip_l"], + "t5xxl": flux_models["t5"], + "ae": flux_models["vae"], + "save_model_as": "safetensors", + "persistent_data_loader_workers": False, + "max_data_loader_n_workers": 0, + "seed": 42, + "network_module": ".networks.lora_flux" if network_config is None else network_config["network_module"], + "dataset_config": dataset_toml, + "output_name": f"{output_name}_rank{kwargs.get('network_dim')}_{save_dtype}", + "loss_type": "l2", + "t5xxl_max_token_length": 512, + "alpha_mask": dataset["alpha_mask"], + "network_train_unet_only": True if train_text_encoder == 'disabled' else False, + "fp8_base_unet": True if "fp8" in train_text_encoder else False, + "disable_mmap_load_safetensors": False, + "network_args": None if network_config is None else network_config["network_args"], + } + attention_settings = { + "sdpa": {"mem_eff_attn": True, "xformers": False, "spda": True}, + "xformers": {"mem_eff_attn": True, "xformers": True, "spda": False} + } + config_dict.update(attention_settings.get(attention_mode, {})) + + gradient_dtype_settings = { + "fp16": {"full_fp16": True, "full_bf16": False, "mixed_precision": "fp16"}, + "bf16": {"full_bf16": True, "full_fp16": False, "mixed_precision": "bf16"} + } + config_dict.update(gradient_dtype_settings.get(gradient_dtype, {})) + + if train_text_encoder != 'disabled': + if T5_lr != "NaN": + config_dict["text_encoder_lr"] = clip_l_lr + if T5_lr != "NaN": + config_dict["text_encoder_lr"] = [clip_l_lr, T5_lr] + + if gradient_checkpointing == "disabled": + config_dict["gradient_checkpointing"] = False + elif gradient_checkpointing == "enabled_with_cpu_offloading": + config_dict["gradient_checkpointing"] = True + config_dict["cpu_offload_checkpointing"] = True + else: + config_dict["gradient_checkpointing"] = True + + if flux_models["lora_path"]: + config_dict["network_weights"] = flux_models["lora_path"] + + config_dict.update(kwargs) + config_dict.update(optimizer_settings) + + if loss_args: + config_dict.update(loss_args) + + if resume_args: + config_dict.update(resume_args) + + for key, value in config_dict.items(): + setattr(args, key, value) + + #network args + additional_network_args = [] + + if "T5" in train_text_encoder: + additional_network_args.append("train_t5xxl=True") + + if block_args: + additional_network_args.append(block_args["include"]) + + # Handle network_args in args Namespace + if hasattr(args, 'network_args') and isinstance(args.network_args, list): + args.network_args.extend(additional_network_args) + else: + setattr(args, 'network_args', additional_network_args) + + saved_args_file_path = os.path.join(output_dir, f"{output_name}_args.json") + with open(saved_args_file_path, 'w') as f: + json.dump(vars(args), f, indent=4) + + #workflow saving + metadata = {} + if extra_pnginfo is not None: + metadata.update(extra_pnginfo["workflow"]) + + saved_workflow_file_path = os.path.join(output_dir, f"{output_name}_workflow.json") + with open(saved_workflow_file_path, 'w') as f: + json.dump(metadata, f, indent=4) + + #pass args to kohya and initialize trainer + with torch.inference_mode(False): + network_trainer = FluxNetworkTrainer() + training_loop = network_trainer.init_train(args) + + epochs_count = network_trainer.num_train_epochs + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, epochs_count, args) + +class InitFluxTraining: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "flux_models": ("TRAIN_FLUX_MODELS",), + "dataset": ("JSON",), + "optimizer_settings": ("ARGS",), + "output_name": ("STRING", {"default": "flux", "multiline": False}), + "output_dir": ("STRING", {"default": "flux_trainer_output", "multiline": False, "tooltip": "path to dataset, root is the 'ComfyUI' folder, with windows portable 'ComfyUI_windows_portable'"}), + "learning_rate": ("FLOAT", {"default": 4e-6, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "learning rate"}), + "max_train_steps": ("INT", {"default": 1500, "min": 1, "max": 100000, "step": 1, "tooltip": "max number of training steps"}), + "apply_t5_attn_mask": ("BOOLEAN", {"default": True, "tooltip": "apply t5 attention mask"}), + "t5xxl_max_token_length": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 8, "tooltip": "dev and LibreFlux uses 512, schnell 256"}), + "cache_latents": (["disk", "memory", "disabled"], {"tooltip": "caches text encoder outputs"}), + "cache_text_encoder_outputs": (["disk", "memory", "disabled"], {"tooltip": "caches text encoder outputs"}), + "weighting_scheme": (["logit_normal", "sigma_sqrt", "mode", "cosmap", "none"],), + "logit_mean": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "mean to use when using the logit_normal weighting scheme"}), + "logit_std": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01,"tooltip": "std to use when using the logit_normal weighting scheme"}), + "mode_scale": ("FLOAT", {"default": 1.29, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Scale of mode weighting scheme. Only effective when using the mode as the weighting_scheme"}), + "loss_type": (["l1", "l2", "huber", "smooth_l1"], {"default": "l2", "tooltip": "loss type"}), + "timestep_sampling": (["sigmoid", "uniform", "sigma", "shift", "flux_shift"], {"tooltip": "Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid (recommend value of 3.1582 for discrete_flow_shift)"}), + "sigmoid_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1, "tooltip": "Scale factor for sigmoid timestep sampling (only used when timestep-sampling is sigmoid"}), + "model_prediction_type": (["raw", "additive", "sigma_scaled"], {"tooltip": "How to interpret and process the model prediction: raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)"}), + "cpu_offload_checkpointing": ("BOOLEAN", {"default": True, "tooltip": "offload the gradient checkpointing to CPU. This reduces VRAM usage for about 2GB"}), + "optimizer_fusing": (['fused_backward_pass', 'blockwise_fused_optimizers'], {"tooltip": "reduces memory use"}), + "blocks_to_swap": ("INT", {"default": 0, "min": 0, "max": 100, "step": 1, "tooltip": "Sets the number of blocks (~640MB) to swap during the forward and backward passes, increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."}), + "guidance_scale": ("FLOAT", {"default": 1.0, "min": 1.0, "max": 32.0, "step": 0.01, "tooltip": "guidance scale"}), + "discrete_flow_shift": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.0001, "tooltip": "for the Euler Discrete Scheduler, default is 3.0"}), + "highvram": ("BOOLEAN", {"default": False, "tooltip": "memory mode"}), + "fp8_base": ("BOOLEAN", {"default": False, "tooltip": "use fp8 for base model"}), + "gradient_dtype": (["fp32", "fp16", "bf16"], {"default": "bf16", "tooltip": "to use the full fp16/bf16 training"}), + "save_dtype": (["fp32", "fp16", "bf16", "fp8_e4m3fn"], {"default": "bf16", "tooltip": "the dtype to save checkpoints as"}), + "attention_mode": (["sdpa", "xformers", "disabled"], {"default": "sdpa", "tooltip": "memory efficient attention mode"}), + "sample_prompts": ("STRING", {"multiline": True, "default": "illustration of a kitten | photograph of a turtle", "tooltip": "validation sample prompts, for multiple prompts, separate by `|`"}), + }, + "optional": { + "additional_args": ("STRING", {"multiline": True, "default": "", "tooltip": "additional args to pass to the training command"}), + "resume_args": ("ARGS", {"default": "", "tooltip": "resume args to pass to the training command"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT", "KOHYA_ARGS") + RETURN_NAMES = ("network_trainer", "epochs_count", "args") + FUNCTION = "init_training" + CATEGORY = "FluxTrainer" + + def init_training(self, flux_models, optimizer_settings, dataset, sample_prompts, output_name, + attention_mode, gradient_dtype, save_dtype, optimizer_fusing, additional_args=None, resume_args=None, **kwargs,): + mm.soft_empty_cache() + + output_dir = os.path.abspath(kwargs.get("output_dir")) + os.makedirs(output_dir, exist_ok=True) + + total, used, free = shutil.disk_usage(output_dir) + required_free_space = 25 * (2**30) + if free <= required_free_space: + raise ValueError(f"Most likely insufficient disk space to complete training. Required: {required_free_space/2**30}GB. Available: {free/2**30}GB") + + dataset_config = dataset["datasets"] + dataset_toml = toml.dumps(json.loads(dataset_config)) + + parser = train_setup_parser() + flux_train_utils.add_flux_train_arguments(parser) + + if additional_args is not None: + print(f"additional_args: {additional_args}") + args, _ = parser.parse_known_args(args=shlex.split(additional_args)) + else: + args, _ = parser.parse_known_args() + + if kwargs.get("cache_latents") == "memory": + kwargs["cache_latents"] = True + kwargs["cache_latents_to_disk"] = False + elif kwargs.get("cache_latents") == "disk": + kwargs["cache_latents"] = True + kwargs["cache_latents_to_disk"] = True + kwargs["caption_dropout_rate"] = 0.0 + kwargs["shuffle_caption"] = False + kwargs["token_warmup_step"] = 0.0 + kwargs["caption_tag_dropout_rate"] = 0.0 + else: + kwargs["cache_latents"] = False + kwargs["cache_latents_to_disk"] = False + + if kwargs.get("cache_text_encoder_outputs") == "memory": + kwargs["cache_text_encoder_outputs"] = True + kwargs["cache_text_encoder_outputs_to_disk"] = False + elif kwargs.get("cache_text_encoder_outputs") == "disk": + kwargs["cache_text_encoder_outputs"] = True + kwargs["cache_text_encoder_outputs_to_disk"] = True + else: + kwargs["cache_text_encoder_outputs"] = False + kwargs["cache_text_encoder_outputs_to_disk"] = False + + if '|' in sample_prompts: + prompts = sample_prompts.split('|') + else: + prompts = [sample_prompts] + + config_dict = { + "sample_prompts": prompts, + "save_precision": save_dtype, + "mixed_precision": "bf16", + "num_cpu_threads_per_process": 1, + "pretrained_model_name_or_path": flux_models["transformer"], + "clip_l": flux_models["clip_l"], + "t5xxl": flux_models["t5"], + "ae": flux_models["vae"], + "save_model_as": "safetensors", + "persistent_data_loader_workers": False, + "max_data_loader_n_workers": 0, + "seed": 42, + "gradient_checkpointing": True, + "dataset_config": dataset_toml, + "output_name": f"{output_name}_{save_dtype}", + "mem_eff_save": True, + "disable_mmap_load_safetensors": True, + + } + optimizer_fusing_settings = { + "fused_backward_pass": {"fused_backward_pass": True, "blockwise_fused_optimizers": False}, + "blockwise_fused_optimizers": {"fused_backward_pass": False, "blockwise_fused_optimizers": True} + } + config_dict.update(optimizer_fusing_settings.get(optimizer_fusing, {})) + + attention_settings = { + "sdpa": {"mem_eff_attn": True, "xformers": False, "spda": True}, + "xformers": {"mem_eff_attn": True, "xformers": True, "spda": False} + } + config_dict.update(attention_settings.get(attention_mode, {})) + + gradient_dtype_settings = { + "fp16": {"full_fp16": True, "full_bf16": False, "mixed_precision": "fp16"}, + "bf16": {"full_bf16": True, "full_fp16": False, "mixed_precision": "bf16"} + } + config_dict.update(gradient_dtype_settings.get(gradient_dtype, {})) + + config_dict.update(kwargs) + config_dict.update(optimizer_settings) + + if resume_args: + config_dict.update(resume_args) + + for key, value in config_dict.items(): + setattr(args, key, value) + + with torch.inference_mode(False): + network_trainer = FluxTrainer() + training_loop = network_trainer.init_train(args) + + epochs_count = network_trainer.num_train_epochs + + + saved_args_file_path = os.path.join(output_dir, f"{output_name}_args.json") + with open(saved_args_file_path, 'w') as f: + json.dump(vars(args), f, indent=4) + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, epochs_count, args) + +class InitFluxTrainingFromPreset: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "flux_models": ("TRAIN_FLUX_MODELS",), + "dataset_settings": ("TOML_DATASET",), + "preset_args": ("KOHYA_ARGS",), + "output_name": ("STRING", {"default": "flux", "multiline": False}), + "output_dir": ("STRING", {"default": "flux_trainer_output", "multiline": False, "tooltip": "output directory, root is ComfyUI folder"}), + "sample_prompts": ("STRING", {"multiline": True, "default": "illustration of a kitten | photograph of a turtle", "tooltip": "validation sample prompts, for multiple prompts, separate by `|`"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT", "STRING", "KOHYA_ARGS") + RETURN_NAMES = ("network_trainer", "epochs_count", "output_path", "args") + FUNCTION = "init_training" + CATEGORY = "FluxTrainer" + + def init_training(self, flux_models, dataset_settings, sample_prompts, output_name, preset_args, **kwargs,): + mm.soft_empty_cache() + + dataset = dataset_settings["dataset"] + dataset_repeats = dataset_settings["repeats"] + + parser = train_setup_parser() + args, _ = parser.parse_known_args() + for key, value in vars(preset_args).items(): + setattr(args, key, value) + + output_dir = os.path.join(script_directory, "output") + if '|' in sample_prompts: + prompts = sample_prompts.split('|') + else: + prompts = [sample_prompts] + + width, height = toml.loads(dataset)["datasets"][0]["resolution"] + config_dict = { + "sample_prompts": prompts, + "dataset_repeats": dataset_repeats, + "num_cpu_threads_per_process": 1, + "pretrained_model_name_or_path": flux_models["transformer"], + "clip_l": flux_models["clip_l"], + "t5xxl": flux_models["t5"], + "ae": flux_models["vae"], + "save_model_as": "safetensors", + "persistent_data_loader_workers": False, + "max_data_loader_n_workers": 0, + "seed": 42, + "gradient_checkpointing": True, + "dataset_config": dataset, + "output_dir": output_dir, + "output_name": f"{output_name}_rank{kwargs.get('network_dim')}_{args.save_precision}", + "width" : int(width), + "height" : int(height), + + } + + config_dict.update(kwargs) + + for key, value in config_dict.items(): + setattr(args, key, value) + + with torch.inference_mode(False): + network_trainer = FluxNetworkTrainer() + training_loop = network_trainer.init_train(args) + + final_output_path = os.path.join(output_dir, output_name) + + epochs_count = network_trainer.num_train_epochs + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, epochs_count, final_output_path, args) + +class FluxTrainLoop: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "steps": ("INT", {"default": 1, "min": 1, "max": 10000, "step": 1, "tooltip": "the step point in training to validate/save"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT",) + RETURN_NAMES = ("network_trainer", "steps",) + FUNCTION = "train" + CATEGORY = "FluxTrainer" + + def train(self, network_trainer, steps): + with torch.inference_mode(False): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + initial_global_step = network_trainer.global_step + + target_global_step = network_trainer.global_step + steps + comfy_pbar = comfy.utils.ProgressBar(steps) + network_trainer.comfy_pbar = comfy_pbar + + network_trainer.optimizer_train_fn() + + while network_trainer.global_step < target_global_step: + steps_done = training_loop( + break_at_steps = target_global_step, + epoch = network_trainer.current_epoch.value, + ) + #pbar.update(steps_done) + + # Also break if the global steps have reached the max train steps + if network_trainer.global_step >= network_trainer.args.max_train_steps: + break + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, network_trainer.global_step) + +class FluxTrainAndValidateLoop: + @classmethod + def INPUT_TYPES(cls): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "validate_at_steps": ("INT", {"default": 250, "min": 1, "max": 10000, "step": 1, "tooltip": "the step point in training to validate/save"}), + "save_at_steps": ("INT", {"default": 250, "min": 1, "max": 10000, "step": 1, "tooltip": "the step point in training to validate/save"}), + }, + "optional": { + "validation_settings": ("VALSETTINGS",), + } + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT",) + RETURN_NAMES = ("network_trainer", "steps",) + FUNCTION = "train" + CATEGORY = "FluxTrainer" + + def train(self, network_trainer, validate_at_steps, save_at_steps, validation_settings=None): + with torch.inference_mode(False): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + + target_global_step = network_trainer.args.max_train_steps + comfy_pbar = comfy.utils.ProgressBar(target_global_step) + network_trainer.comfy_pbar = comfy_pbar + + network_trainer.optimizer_train_fn() + + while network_trainer.global_step < target_global_step: + next_validate_step = ((network_trainer.global_step // validate_at_steps) + 1) * validate_at_steps + next_save_step = ((network_trainer.global_step // save_at_steps) + 1) * save_at_steps + + steps_done = training_loop( + break_at_steps=min(next_validate_step, next_save_step), + epoch=network_trainer.current_epoch.value, + ) + + # Check if we need to validate + if network_trainer.global_step % validate_at_steps == 0: + self.validate(network_trainer, validation_settings) + + # Check if we need to save + if network_trainer.global_step % save_at_steps == 0: + self.save(network_trainer) + + # Also break if the global steps have reached the max train steps + if network_trainer.global_step >= network_trainer.args.max_train_steps: + break + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, network_trainer.global_step) + + def validate(self, network_trainer, validation_settings=None): + params = ( + network_trainer.current_epoch.value, + network_trainer.global_step, + validation_settings + ) + network_trainer.optimizer_eval_fn() + image_tensors = network_trainer.sample_images(*params) + network_trainer.optimizer_train_fn() + print("Validating at step:", network_trainer.global_step) + + def save(self, network_trainer): + ckpt_name = train_util.get_step_ckpt_name(network_trainer.args, "." + network_trainer.args.save_model_as, network_trainer.global_step) + network_trainer.optimizer_eval_fn() + network_trainer.save_model(ckpt_name, network_trainer.accelerator.unwrap_model(network_trainer.network), network_trainer.global_step, network_trainer.current_epoch.value + 1) + network_trainer.optimizer_train_fn() + print("Saving at step:", network_trainer.global_step) + +class FluxTrainSave: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "save_state": ("BOOLEAN", {"default": False, "tooltip": "save the whole model state as well"}), + "copy_to_comfy_lora_folder": ("BOOLEAN", {"default": False, "tooltip": "copy the lora model to the comfy lora folder"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "STRING", "INT",) + RETURN_NAMES = ("network_trainer","lora_path", "steps",) + FUNCTION = "save" + CATEGORY = "FluxTrainer" + + def save(self, network_trainer, save_state, copy_to_comfy_lora_folder): + import shutil + with torch.inference_mode(False): + trainer = network_trainer["network_trainer"] + global_step = trainer.global_step + + ckpt_name = train_util.get_step_ckpt_name(trainer.args, "." + trainer.args.save_model_as, global_step) + trainer.save_model(ckpt_name, trainer.accelerator.unwrap_model(trainer.network), global_step, trainer.current_epoch.value + 1) + + remove_step_no = train_util.get_remove_step_no(trainer.args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(trainer.args, "." + trainer.args.save_model_as, remove_step_no) + trainer.remove_model(remove_ckpt_name) + + if save_state: + train_util.save_and_remove_state_stepwise(trainer.args, trainer.accelerator, global_step) + + lora_path = os.path.join(trainer.args.output_dir, ckpt_name) + if copy_to_comfy_lora_folder: + destination_dir = os.path.join(folder_paths.models_dir, "loras", "flux_trainer") + os.makedirs(destination_dir, exist_ok=True) + shutil.copy(lora_path, os.path.join(destination_dir, ckpt_name)) + + + return (network_trainer, lora_path, global_step) + +class FluxTrainSaveModel: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "copy_to_comfy_model_folder": ("BOOLEAN", {"default": False, "tooltip": "copy the lora model to the comfy lora folder"}), + "end_training": ("BOOLEAN", {"default": False, "tooltip": "end the training"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "STRING", "INT",) + RETURN_NAMES = ("network_trainer","model_path", "steps",) + FUNCTION = "save" + CATEGORY = "FluxTrainer" + + def save(self, network_trainer, copy_to_comfy_model_folder, end_training): + import shutil + with torch.inference_mode(False): + trainer = network_trainer["network_trainer"] + global_step = trainer.global_step + + trainer.optimizer_eval_fn() + + ckpt_name = train_util.get_step_ckpt_name(trainer.args, "." + trainer.args.save_model_as, global_step) + flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + trainer.args, + False, + trainer.accelerator, + trainer.save_dtype, + trainer.current_epoch.value, + trainer.num_train_epochs, + global_step, + trainer.accelerator.unwrap_model(trainer.unet) + ) + + model_path = os.path.join(trainer.args.output_dir, ckpt_name) + if copy_to_comfy_model_folder: + shutil.copy(model_path, os.path.join(folder_paths.models_dir, "diffusion_models", "flux_trainer", ckpt_name)) + model_path = os.path.join(folder_paths.models_dir, "diffusion_models", "flux_trainer", ckpt_name) + if end_training: + trainer.accelerator.end_training() + + return (network_trainer, model_path, global_step) + +class FluxTrainEnd: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "save_state": ("BOOLEAN", {"default": True}), + }, + } + + RETURN_TYPES = ("STRING", "STRING", "STRING",) + RETURN_NAMES = ("lora_name", "metadata", "lora_path",) + FUNCTION = "endtrain" + CATEGORY = "FluxTrainer" + OUTPUT_NODE = True + + def endtrain(self, network_trainer, save_state): + with torch.inference_mode(False): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + + network_trainer.metadata["ss_epoch"] = str(network_trainer.num_train_epochs) + network_trainer.metadata["ss_training_finished_at"] = str(time.time()) + + network = network_trainer.accelerator.unwrap_model(network_trainer.network) + + network_trainer.accelerator.end_training() + network_trainer.optimizer_eval_fn() + + if save_state: + train_util.save_state_on_train_end(network_trainer.args, network_trainer.accelerator) + + ckpt_name = train_util.get_last_ckpt_name(network_trainer.args, "." + network_trainer.args.save_model_as) + network_trainer.save_model(ckpt_name, network, network_trainer.global_step, network_trainer.num_train_epochs, force_sync_upload=True) + logger.info("model saved.") + + final_lora_name = str(network_trainer.args.output_name) + final_lora_path = os.path.join(network_trainer.args.output_dir, ckpt_name) + + # metadata + metadata = json.dumps(network_trainer.metadata, indent=2) + + training_loop = None + network_trainer = None + mm.soft_empty_cache() + + return (final_lora_name, metadata, final_lora_path) + +class FluxTrainResume: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "load_state_path": ("STRING", {"default": "", "multiline": True, "tooltip": "path to load state from"}), + "skip_until_initial_step" : ("BOOLEAN", {"default": False}), + }, + } + + RETURN_TYPES = ("ARGS", ) + RETURN_NAMES = ("resume_args", ) + FUNCTION = "resume" + CATEGORY = "FluxTrainer" + + def resume(self, load_state_path, skip_until_initial_step): + resume_args ={ + "resume": load_state_path, + "skip_until_initial_step": skip_until_initial_step + } + + return (resume_args, ) + +class FluxTrainBlockSelect: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "include": ("STRING", {"default": "lora_unet_single_blocks_20_linear2", "multiline": True, "tooltip": "blocks to include in the LoRA network, to select multiple blocks either input them as "}), + }, + } + + RETURN_TYPES = ("ARGS", ) + RETURN_NAMES = ("block_args", ) + FUNCTION = "block_select" + CATEGORY = "FluxTrainer" + + def block_select(self, include): + import re + + # Split the input string by commas to handle multiple ranges/blocks + elements = include.split(',') + + # Initialize a list to collect block names + blocks = [] + + # Pattern to find ranges like (10-20) + pattern = re.compile(r'\((\d+)-(\d+)\)') + + # Extract the prefix and suffix from the first element + prefix_suffix_pattern = re.compile(r'(.*)_blocks_(.*)') + + for element in elements: + element = element.strip() + match = prefix_suffix_pattern.match(element) + if match: + prefix = match.group(1) + "_blocks_" + suffix = match.group(2) + matches = pattern.findall(suffix) + if matches: + for start, end in matches: + # Generate block names for the range and add them to the list + blocks.extend([f"{prefix}{i}{suffix.replace(f'({start}-{end})', '', 1)}" for i in range(int(start), int(end) + 1)]) + else: + # If no range is found, add the block name directly + blocks.append(element) + else: + blocks.append(element) + + # Construct the `include` string + include_string = ','.join(blocks) + + block_args = { + "include": f"only_if_contains={include_string}", + } + + return (block_args, ) + +class FluxTrainValidationSettings: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "steps": ("INT", {"default": 20, "min": 1, "max": 256, "step": 1, "tooltip": "sampling steps"}), + "width": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 8, "tooltip": "image width"}), + "height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 8, "tooltip": "image height"}), + "guidance_scale": ("FLOAT", {"default": 3.5, "min": 1.0, "max": 32.0, "step": 0.05, "tooltip": "guidance scale"}), + "seed": ("INT", {"default": 42,"min": 0, "max": 0xffffffffffffffff, "step": 1}), + "shift": ("BOOLEAN", {"default": True, "tooltip": "shift the schedule to favor high timesteps for higher signal images"}), + "base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}), + "max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 10.0, "step": 0.01}), + }, + } + + RETURN_TYPES = ("VALSETTINGS", ) + RETURN_NAMES = ("validation_settings", ) + FUNCTION = "set" + CATEGORY = "FluxTrainer" + + def set(self, **kwargs): + validation_settings = kwargs + print(validation_settings) + + return (validation_settings,) + +class FluxTrainValidate: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "network_trainer": ("NETWORKTRAINER",), + }, + "optional": { + "validation_settings": ("VALSETTINGS",), + } + } + + RETURN_TYPES = ("NETWORKTRAINER", "IMAGE",) + RETURN_NAMES = ("network_trainer", "validation_images",) + FUNCTION = "validate" + CATEGORY = "FluxTrainer" + + def validate(self, network_trainer, validation_settings=None): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + + params = ( + network_trainer.current_epoch.value, + network_trainer.global_step, + validation_settings + ) + network_trainer.optimizer_eval_fn() + with torch.inference_mode(False): + image_tensors = network_trainer.sample_images(*params) + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, (0.5 * (image_tensors + 1.0)).cpu().float(),) + +class VisualizeLoss: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "plot_style": (plt.style.available,{"default": 'default', "tooltip": "matplotlib plot style"}), + "window_size": ("INT", {"default": 100, "min": 0, "max": 10000, "step": 1, "tooltip": "the window size of the moving average"}), + "normalize_y": ("BOOLEAN", {"default": True, "tooltip": "normalize the y-axis to 0"}), + "width": ("INT", {"default": 768, "min": 256, "max": 4096, "step": 2, "tooltip": "width of the plot in pixels"}), + "height": ("INT", {"default": 512, "min": 256, "max": 4096, "step": 2, "tooltip": "height of the plot in pixels"}), + "log_scale": ("BOOLEAN", {"default": False, "tooltip": "use log scale on the y-axis"}), + }, + } + + RETURN_TYPES = ("IMAGE", "FLOAT",) + RETURN_NAMES = ("plot", "loss_list",) + FUNCTION = "draw" + CATEGORY = "FluxTrainer" + + def draw(self, network_trainer, window_size, plot_style, normalize_y, width, height, log_scale): + import numpy as np + loss_values = network_trainer["network_trainer"].loss_recorder.global_loss_list + + # Apply moving average + def moving_average(values, window_size): + return np.convolve(values, np.ones(window_size) / window_size, mode='valid') + if window_size > 0: + loss_values = moving_average(loss_values, window_size) + + plt.style.use(plot_style) + + # Convert pixels to inches (assuming 100 pixels per inch) + width_inches = width / 100 + height_inches = height / 100 + + # Create a plot + fig, ax = plt.subplots(figsize=(width_inches, height_inches)) + ax.plot(loss_values, label='Training Loss') + ax.set_xlabel('Step') + ax.set_ylabel('Loss') + if normalize_y: + plt.ylim(bottom=0) + if log_scale: + ax.set_yscale('log') + ax.set_title('Training Loss Over Time') + ax.legend() + ax.grid(True) + + buf = io.BytesIO() + plt.savefig(buf, format='png') + plt.close(fig) + buf.seek(0) + + image = Image.open(buf).convert('RGB') + + image_tensor = transforms.ToTensor()(image) + image_tensor = image_tensor.unsqueeze(0).permute(0, 2, 3, 1).cpu().float() + + return image_tensor, loss_values, + +class FluxKohyaInferenceSampler: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "flux_models": ("TRAIN_FLUX_MODELS",), + "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), + "lora_method": (["apply", "merge"], {"tooltip": "whether to apply or merge the lora weights"}), + "steps": ("INT", {"default": 20, "min": 1, "max": 256, "step": 1, "tooltip": "sampling steps"}), + "width": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 8, "tooltip": "image width"}), + "height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 8, "tooltip": "image height"}), + "guidance_scale": ("FLOAT", {"default": 3.5, "min": 1.0, "max": 32.0, "step": 0.05, "tooltip": "guidance scale"}), + "seed": ("INT", {"default": 42,"min": 0, "max": 0xffffffffffffffff, "step": 1}), + "use_fp8": ("BOOLEAN", {"default": True, "tooltip": "use fp8 weights"}), + "apply_t5_attn_mask": ("BOOLEAN", {"default": True, "tooltip": "use t5 attention mask"}), + "prompt": ("STRING", {"multiline": True, "default": "illustration of a kitten", "tooltip": "prompt"}), + + }, + } + + RETURN_TYPES = ("IMAGE", ) + RETURN_NAMES = ("image", ) + FUNCTION = "sample" + CATEGORY = "FluxTrainer" + + def sample(self, flux_models, lora_name, steps, width, height, guidance_scale, seed, prompt, use_fp8, lora_method, apply_t5_attn_mask): + + from .library import flux_utils as flux_utils + from .library import strategy_flux as strategy_flux + from .networks import lora_flux as lora_flux + from typing import List, Optional, Callable + from tqdm import tqdm + import einops + import math + import accelerate + import gc + + device = "cuda" + + + if use_fp8: + accelerator = accelerate.Accelerator(mixed_precision="bf16") + dtype = torch.float8_e4m3fn + else: + dtype = torch.float16 + accelerator = None + loading_device = "cpu" + ae_dtype = torch.bfloat16 + + pretrained_model_name_or_path = flux_models["transformer"] + clip_l = flux_models["clip_l"] + t5xxl = flux_models["t5"] + ae = flux_models["vae"] + lora_path = folder_paths.get_full_path("loras", lora_name) + + # load clip_l + logger.info(f"Loading clip_l from {clip_l}...") + clip_l = flux_utils.load_clip_l(clip_l, None, loading_device) + clip_l.eval() + + logger.info(f"Loading t5xxl from {t5xxl}...") + t5xxl = flux_utils.load_t5xxl(t5xxl, None, loading_device) + t5xxl.eval() + + if use_fp8: + clip_l = accelerator.prepare(clip_l) + t5xxl = accelerator.prepare(t5xxl) + + t5xxl_max_length = 512 + tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length) + encoding_strategy = strategy_flux.FluxTextEncodingStrategy() + + # DiT + model = flux_utils.load_flow_model("dev", pretrained_model_name_or_path, dtype, loading_device) + model.eval() + logger.info(f"Casting model to {dtype}") + model.to(dtype) # make sure model is dtype + if use_fp8: + model = accelerator.prepare(model) + + # AE + ae = flux_utils.load_ae("dev", ae, ae_dtype, loading_device) + ae.eval() + + + # LoRA + lora_models: List[lora_flux.LoRANetwork] = [] + multiplier = 1.0 + + lora_model, weights_sd = lora_flux.create_network_from_weights( + multiplier, lora_path, ae, [clip_l, t5xxl], model, None, True + ) + if lora_method == "merge": + lora_model.merge_to([clip_l, t5xxl], model, weights_sd) + elif lora_method == "apply": + lora_model.apply_to([clip_l, t5xxl], model) + info = lora_model.load_state_dict(weights_sd, strict=True) + logger.info(f"Loaded LoRA weights from {lora_name}: {info}") + lora_model.eval() + lora_model.to(device) + lora_models.append(lora_model) + + + packed_latent_height, packed_latent_width = math.ceil(height / 16), math.ceil(width / 16) + noise = torch.randn( + 1, + packed_latent_height * packed_latent_width, + 16 * 2 * 2, + device=device, + dtype=ae_dtype, + generator=torch.Generator(device=device).manual_seed(seed), + ) + + img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width) + + # prepare embeddings + logger.info("Encoding prompts...") + tokens_and_masks = tokenize_strategy.tokenize(prompt) + clip_l = clip_l.to(device) + t5xxl = t5xxl.to(device) + with torch.no_grad(): + if use_fp8: + clip_l.to(ae_dtype) + t5xxl.to(ae_dtype) + with accelerator.autocast(): + l_pooled, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, apply_t5_attn_mask + ) + else: + with torch.autocast(device_type=device.type, dtype=dtype): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + with torch.autocast(device_type=device.type, dtype=dtype): + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [None, t5xxl], tokens_and_masks, apply_t5_attn_mask + ) + # NaN check + if torch.isnan(l_pooled).any(): + raise ValueError("NaN in l_pooled") + + if torch.isnan(t5_out).any(): + raise ValueError("NaN in t5_out") + + + clip_l = clip_l.cpu() + t5xxl = t5xxl.cpu() + + gc.collect() + torch.cuda.empty_cache() + + # generate image + logger.info("Generating image...") + model = model.to(device) + print("MODEL DTYPE: ", model.dtype) + + img_ids = img_ids.to(device) + t5_attn_mask = t5_attn_mask.to(device) if apply_t5_attn_mask else None + def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + + def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + + def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, + ) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + + def denoise( + model, + img: torch.Tensor, + img_ids: torch.Tensor, + txt: torch.Tensor, + txt_ids: torch.Tensor, + vec: torch.Tensor, + timesteps: list[float], + guidance: float = 4.0, + t5_attn_mask: Optional[torch.Tensor] = None, + ): + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + comfy_pbar = comfy.utils.ProgressBar(total=len(timesteps)) + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + img = img + (t_prev - t_curr) * pred + comfy_pbar.update(1) + + return img + def do_sample( + accelerator: Optional[accelerate.Accelerator], + model, + img: torch.Tensor, + img_ids: torch.Tensor, + l_pooled: torch.Tensor, + t5_out: torch.Tensor, + txt_ids: torch.Tensor, + num_steps: int, + guidance: float, + t5_attn_mask: Optional[torch.Tensor], + is_schnell: bool, + device: torch.device, + flux_dtype: torch.dtype, + ): + timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell) + print(timesteps) + + # denoise initial noise + if accelerator: + with accelerator.autocast(), torch.no_grad(): + x = denoise( + model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask + ) + else: + with torch.autocast(device_type=device.type, dtype=flux_dtype): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + with torch.autocast(device_type=device.type, dtype=flux_dtype): + _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( + tokenize_strategy, [None, t5xxl], tokens_and_masks, apply_t5_attn_mask + ) + + return x + + x = do_sample(accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance_scale, t5_attn_mask, False, device, dtype) + + model = model.cpu() + gc.collect() + torch.cuda.empty_cache() + + # unpack + x = x.float() + x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) + + # decode + logger.info("Decoding image...") + ae = ae.to(device) + with torch.no_grad(): + if use_fp8: + with accelerator.autocast(): + x = ae.decode(x) + else: + with torch.autocast(device_type=device.type, dtype=ae_dtype): + x = ae.decode(x) + + ae = ae.cpu() + + x = x.clamp(-1, 1) + x = x.permute(0, 2, 3, 1) + + return ((0.5 * (x + 1.0)).cpu().float(),) + +class UploadToHuggingFace: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "network_trainer": ("NETWORKTRAINER",), + "source_path": ("STRING", {"default": ""}), + "repo_id": ("STRING",{"default": ""}), + "revision": ("STRING", {"default": ""}), + "private": ("BOOLEAN", {"default": True, "tooltip": "If creating a new repo, leave it private"}), + }, + "optional": { + "token": ("STRING", {"default": "","tooltip":"DO NOT LEAVE IN THE NODE or it might save in metadata, can also use the hf_token.json"}), + } + } + + RETURN_TYPES = ("NETWORKTRAINER", "STRING",) + RETURN_NAMES = ("network_trainer","status",) + FUNCTION = "upload" + CATEGORY = "FluxTrainer" + + def upload(self, source_path, network_trainer, repo_id, private, revision, token=""): + with torch.inference_mode(False): + from huggingface_hub import HfApi + + if not token: + with open(os.path.join(script_directory, "hf_token.json"), "r") as file: + token_data = json.load(file) + token = token_data["hf_token"] + print(token) + + # Save metadata to a JSON file + directory_path = os.path.dirname(os.path.dirname(source_path)) + file_name = os.path.basename(source_path) + + metadata = network_trainer["network_trainer"].metadata + metadata_file_path = os.path.join(directory_path, "metadata.json") + with open(metadata_file_path, 'w') as f: + json.dump(metadata, f, indent=4) + + repo_type = None + api = HfApi(token=token) + + try: + api.repo_info( + repo_id=repo_id, + revision=revision if revision != "" else None, + repo_type=repo_type) + repo_exists = True + logger.info(f"Repository {repo_id} exists.") + except Exception as e: # Catching a more specific exception would be better if you know what to expect + repo_exists = False + logger.error(f"Repository {repo_id} does not exist. Exception: {e}") + + if not repo_exists: + try: + api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) + except Exception as e: # Checked for RepositoryNotFoundError, but other exceptions could be problematic + logger.error("===========================================") + logger.error(f"failed to create HuggingFace repo: {e}") + logger.error("===========================================") + + is_folder = (type(source_path) == str and os.path.isdir(source_path)) or (isinstance(source_path, Path) and source_path.is_dir()) + print(source_path, is_folder) + + try: + if is_folder: + api.upload_folder( + repo_id=repo_id, + repo_type=repo_type, + folder_path=source_path, + path_in_repo=file_name, + ) + else: + api.upload_file( + repo_id=repo_id, + repo_type=repo_type, + path_or_fileobj=source_path, + path_in_repo=file_name, + ) + # Upload the metadata file separately if it's not a folder upload + if not is_folder: + api.upload_file( + repo_id=repo_id, + repo_type=repo_type, + path_or_fileobj=str(metadata_file_path), + path_in_repo='metadata.json', + ) + status = "Uploaded to HuggingFace succesfully" + except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので + logger.error("===========================================") + logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") + logger.error("===========================================") + status = f"Failed to upload to HuggingFace {e}" + + return (network_trainer, status,) + +class ExtractFluxLoRA: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "original_model": (folder_paths.get_filename_list("unet"), ), + "finetuned_model": (folder_paths.get_filename_list("unet"), ), + "output_path": ("STRING", {"default": f"{str(os.path.join(folder_paths.models_dir, 'loras', 'Flux'))}"}), + "dim": ("INT", {"default": 4, "min": 2, "max": 1024, "step": 2, "tooltip": "LoRA rank"}), + "save_dtype": (["fp32", "fp16", "bf16", "fp8_e4m3fn", "fp8_e5m2"], {"default": "bf16", "tooltip": "the dtype to save the LoRA as"}), + "load_device": (["cpu", "cuda"], {"default": "cuda", "tooltip": "the device to load the model to"}), + "store_device": (["cpu", "cuda"], {"default": "cpu", "tooltip": "the device to store the LoRA as"}), + "clamp_quantile": ("FLOAT", {"default": 0.99, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "clamp quantile"}), + "metadata": ("BOOLEAN", {"default": True, "tooltip": "build metadata"}), + "mem_eff_safe_open": ("BOOLEAN", {"default": False, "tooltip": "memory efficient loading"}), + }, + } + + RETURN_TYPES = ("STRING", ) + RETURN_NAMES = ("output_path",) + FUNCTION = "extract" + CATEGORY = "FluxTrainer" + + def extract(self, original_model, finetuned_model, output_path, dim, save_dtype, load_device, store_device, clamp_quantile, metadata, mem_eff_safe_open): + from .flux_extract_lora import svd + transformer_path = folder_paths.get_full_path("unet", original_model) + finetuned_model_path = folder_paths.get_full_path("unet", finetuned_model) + outpath = svd( + model_org = transformer_path, + model_tuned = finetuned_model_path, + save_to = os.path.join(output_path, f"{finetuned_model.replace('.safetensors', '')}_extracted_lora_rank_{dim}-{save_dtype}.safetensors"), + dim = dim, + device = load_device, + store_device = store_device, + save_precision = save_dtype, + clamp_quantile = clamp_quantile, + no_metadata = not metadata, + mem_eff_safe_open = mem_eff_safe_open + ) + + return (outpath,) + +NODE_CLASS_MAPPINGS = { + "InitFluxLoRATraining": InitFluxLoRATraining, + "InitFluxTraining": InitFluxTraining, + "FluxTrainModelSelect": FluxTrainModelSelect, + "TrainDatasetGeneralConfig": TrainDatasetGeneralConfig, + "TrainDatasetAdd": TrainDatasetAdd, + "FluxTrainLoop": FluxTrainLoop, + "VisualizeLoss": VisualizeLoss, + "FluxTrainValidate": FluxTrainValidate, + "FluxTrainValidationSettings": FluxTrainValidationSettings, + "FluxTrainEnd": FluxTrainEnd, + "FluxTrainSave": FluxTrainSave, + "FluxKohyaInferenceSampler": FluxKohyaInferenceSampler, + "UploadToHuggingFace": UploadToHuggingFace, + "OptimizerConfig": OptimizerConfig, + "OptimizerConfigAdafactor": OptimizerConfigAdafactor, + "FluxTrainSaveModel": FluxTrainSaveModel, + "ExtractFluxLoRA": ExtractFluxLoRA, + "OptimizerConfigProdigy": OptimizerConfigProdigy, + "FluxTrainResume": FluxTrainResume, + "FluxTrainBlockSelect": FluxTrainBlockSelect, + "TrainDatasetRegularization": TrainDatasetRegularization, + "FluxTrainAndValidateLoop": FluxTrainAndValidateLoop, + "OptimizerConfigProdigyPlusScheduleFree": OptimizerConfigProdigyPlusScheduleFree, + "FluxTrainerLossConfig": FluxTrainerLossConfig, + "TrainNetworkConfig": TrainNetworkConfig, +} +NODE_DISPLAY_NAME_MAPPINGS = { + "InitFluxLoRATraining": "Init Flux LoRA Training", + "InitFluxTraining": "Init Flux Training", + "FluxTrainModelSelect": "FluxTrain ModelSelect", + "TrainDatasetGeneralConfig": "TrainDatasetGeneralConfig", + "TrainDatasetAdd": "TrainDatasetAdd", + "FluxTrainLoop": "Flux Train Loop", + "VisualizeLoss": "Visualize Loss", + "FluxTrainValidate": "Flux Train Validate", + "FluxTrainValidationSettings": "Flux Train Validation Settings", + "FluxTrainEnd": "Flux LoRA Train End", + "FluxTrainSave": "Flux Train Save LoRA", + "FluxKohyaInferenceSampler": "Flux Kohya Inference Sampler", + "UploadToHuggingFace": "Upload To HuggingFace", + "OptimizerConfig": "Optimizer Config", + "OptimizerConfigAdafactor": "Optimizer Config Adafactor", + "FluxTrainSaveModel": "Flux Train Save Model", + "ExtractFluxLoRA": "Extract Flux LoRA", + "OptimizerConfigProdigy": "Optimizer Config Prodigy", + "FluxTrainResume": "Flux Train Resume", + "FluxTrainBlockSelect": "Flux Train Block Select", + "TrainDatasetRegularization": "Train Dataset Regularization", + "FluxTrainAndValidateLoop": "Flux Train And Validate Loop", + "OptimizerConfigProdigyPlusScheduleFree": "Optimizer Config ProdigyPlusScheduleFree", + "FluxTrainerLossConfig": "Flux Trainer Loss Config", + "TrainNetworkConfig": "Train Network Config", +} diff --git a/nodes_sd3.py b/nodes_sd3.py new file mode 100644 index 0000000000000000000000000000000000000000..512d0af11f4b48febacfd1cc932cfd32ba057635 --- /dev/null +++ b/nodes_sd3.py @@ -0,0 +1,467 @@ +import os +import torch + +import folder_paths +import comfy.model_management as mm +import comfy.utils +import toml +import json +import time +import shutil +import shlex + +script_directory = os.path.dirname(os.path.abspath(__file__)) + +from .sd3_train_network import Sd3NetworkTrainer +from .library import sd3_train_utils as sd3_train_utils +from .library.device_utils import init_ipex +init_ipex() + +from .library import train_util +from .train_network import setup_parser as train_network_setup_parser + +import logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +class SD3ModelSelect: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "transformer": (folder_paths.get_filename_list("checkpoints"), ), + "clip_l": (folder_paths.get_filename_list("clip"), ), + "clip_g": (folder_paths.get_filename_list("clip"), ), + "t5": (folder_paths.get_filename_list("clip"), ), + }, + "optional": { + "lora_path": ("STRING",{"multiline": True, "forceInput": True, "default": "", "tooltip": "pre-trained LoRA path to load (network_weights)"}), + } + } + + RETURN_TYPES = ("TRAIN_SD3_MODELS",) + RETURN_NAMES = ("sd3_models",) + FUNCTION = "loadmodel" + CATEGORY = "FluxTrainer/SD3" + + def loadmodel(self, transformer, clip_l, clip_g, t5, lora_path=""): + + transformer_path = folder_paths.get_full_path("checkpoints", transformer) + clip_l_path = folder_paths.get_full_path("clip", clip_l) + clip_g_path = folder_paths.get_full_path("clip", clip_g) + t5_path = folder_paths.get_full_path("clip", t5) + + sd3_models = { + "transformer": transformer_path, + "clip_l": clip_l_path, + "clip_g": clip_g_path, + "t5": t5_path, + "lora_path": lora_path + } + + return (sd3_models,) + +class InitSD3LoRATraining: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "sd3_models": ("TRAIN_SD3_MODELS",), + "dataset": ("JSON",), + "optimizer_settings": ("ARGS",), + "output_name": ("STRING", {"default": "sd35_lora", "multiline": False}), + "output_dir": ("STRING", {"default": "sd35_trainer_output", "multiline": False, "tooltip": "path to dataset, root is the 'ComfyUI' folder, with windows portable 'ComfyUI_windows_portable'"}), + "network_dim": ("INT", {"default": 16, "min": 1, "max": 2048, "step": 1, "tooltip": "network dim"}), + "network_alpha": ("FLOAT", {"default": 16, "min": 0.0, "max": 2048.0, "step": 0.01, "tooltip": "network alpha"}), + "learning_rate": ("FLOAT", {"default": 1e-4, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "learning rate"}), + "max_train_steps": ("INT", {"default": 1500, "min": 1, "max": 100000, "step": 1, "tooltip": "max number of training steps"}), + "cache_latents": (["disk", "memory", "disabled"], {"tooltip": "caches text encoder outputs"}), + "cache_text_encoder_outputs": (["disk", "memory", "disabled"], {"tooltip": "caches text encoder outputs"}), + "training_shift ": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.0001, "tooltip": "shift value for the training distribution of timesteps"}), + "highvram": ("BOOLEAN", {"default": False, "tooltip": "memory mode"}), + "blocks_to_swap": ("INT", {"default": 0, "min": 0, "max": 100, "step": 1, "tooltip": "option for memory use reduction. The maximum number of blocks that can be swapped is 36 for SD3.5L and 22 for SD3.5M"}), + "fp8_base": ("BOOLEAN", {"default": False, "tooltip": "use fp8 for base model"}), + "gradient_dtype": (["fp32", "fp16", "bf16"], {"default": "fp32", "tooltip": "the actual dtype training uses"}), + "save_dtype": (["fp32", "fp16", "bf16", "fp8_e4m3fn", "fp8_e5m2"], {"default": "bf16", "tooltip": "the dtype to save checkpoints as"}), + "attention_mode": (["sdpa", "xformers", "disabled"], {"default": "sdpa", "tooltip": "memory efficient attention mode"}), + "train_text_encoder": (['disabled', 'clip_l', 'clip_l_fp8', 'clip_l+T5', 'clip_l+T5_fp8'], {"default": 'disabled', "tooltip": "also train the selected text encoders using specified dtype, T5 can not be trained without clip_l"}), + "clip_l_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}), + "clip_g_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}), + "T5_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}), + "sample_prompts": ("STRING", {"multiline": True, "default": "illustration of a kitten | photograph of a turtle", "tooltip": "validation sample prompts, for multiple prompts, separate by `|`"}), + "gradient_checkpointing": (["enabled", "disabled"], {"default": "enabled", "tooltip": "use gradient checkpointing"}), + }, + "optional": { + "additional_args": ("STRING", {"multiline": True, "default": "", "tooltip": "additional args to pass to the training command"}), + "resume_args": ("ARGS", {"default": "", "tooltip": "resume args to pass to the training command"}), + "block_args": ("ARGS", {"default": "", "tooltip": "limit the blocks used in the LoRA"}), + "loss_args": ("ARGS", {"default": "", "tooltip": "loss args"}), + }, + "hidden": { + "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO" + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT", "KOHYA_ARGS",) + RETURN_NAMES = ("network_trainer", "epochs_count", "args",) + FUNCTION = "init_training" + CATEGORY = "FluxTrainer/SD3" + + def init_training(self, sd3_models, dataset, optimizer_settings, sample_prompts, output_name, attention_mode, + gradient_dtype, save_dtype, additional_args=None, resume_args=None, train_text_encoder='disabled', + block_args=None, gradient_checkpointing="enabled", prompt=None, extra_pnginfo=None, clip_l_lr=0, clip_g_lr=0, T5_lr=0, loss_args=None, **kwargs): + mm.soft_empty_cache() + + output_dir = os.path.abspath(kwargs.get("output_dir")) + os.makedirs(output_dir, exist_ok=True) + + total, used, free = shutil.disk_usage(output_dir) + + required_free_space = 2 * (2**30) + if free <= required_free_space: + raise ValueError(f"Insufficient disk space. Required: {required_free_space/2**30}GB. Available: {free/2**30}GB") + + dataset_config = dataset["datasets"] + dataset_toml = toml.dumps(json.loads(dataset_config)) + + parser = train_network_setup_parser() + sd3_train_utils.add_sd3_training_arguments(parser) + if additional_args is not None: + print(f"additional_args: {additional_args}") + args, _ = parser.parse_known_args(args=shlex.split(additional_args)) + else: + args, _ = parser.parse_known_args() + + if kwargs.get("cache_latents") == "memory": + kwargs["cache_latents"] = True + kwargs["cache_latents_to_disk"] = False + elif kwargs.get("cache_latents") == "disk": + kwargs["cache_latents"] = True + kwargs["cache_latents_to_disk"] = True + kwargs["caption_dropout_rate"] = 0.0 + kwargs["shuffle_caption"] = False + kwargs["token_warmup_step"] = 0.0 + kwargs["caption_tag_dropout_rate"] = 0.0 + else: + kwargs["cache_latents"] = False + kwargs["cache_latents_to_disk"] = False + + if kwargs.get("cache_text_encoder_outputs") == "memory": + kwargs["cache_text_encoder_outputs"] = True + kwargs["cache_text_encoder_outputs_to_disk"] = False + elif kwargs.get("cache_text_encoder_outputs") == "disk": + kwargs["cache_text_encoder_outputs"] = True + kwargs["cache_text_encoder_outputs_to_disk"] = True + else: + kwargs["cache_text_encoder_outputs"] = False + kwargs["cache_text_encoder_outputs_to_disk"] = False + + if '|' in sample_prompts: + prompts = sample_prompts.split('|') + else: + prompts = [sample_prompts] + + config_dict = { + "sample_prompts": prompts, + "save_precision": save_dtype, + "mixed_precision": "bf16", + "num_cpu_threads_per_process": 1, + "pretrained_model_name_or_path": sd3_models["transformer"], + "clip_l": sd3_models["clip_l"], + "clip_g": sd3_models["clip_g"], + "t5xxl": sd3_models["t5"], + "save_model_as": "safetensors", + "persistent_data_loader_workers": False, + "max_data_loader_n_workers": 0, + "seed": 42, + "network_module": ".networks.lora_sd3", + "dataset_config": dataset_toml, + "output_name": f"{output_name}_rank{kwargs.get('network_dim')}_{save_dtype}", + "loss_type": "l2", + "t5xxl_max_token_length": 512, + "alpha_mask": dataset["alpha_mask"], + "network_train_unet_only": True if train_text_encoder == 'disabled' else False, + "fp8_base_unet": True if "fp8" in train_text_encoder else False, + "disable_mmap_load_safetensors": False, + } + attention_settings = { + "sdpa": {"mem_eff_attn": True, "xformers": False, "spda": True}, + "xformers": {"mem_eff_attn": True, "xformers": True, "spda": False} + } + config_dict.update(attention_settings.get(attention_mode, {})) + + gradient_dtype_settings = { + "fp16": {"full_fp16": True, "full_bf16": False, "mixed_precision": "fp16"}, + "bf16": {"full_bf16": True, "full_fp16": False, "mixed_precision": "bf16"} + } + config_dict.update(gradient_dtype_settings.get(gradient_dtype, {})) + + if train_text_encoder != 'disabled': + config_dict["text_encoder_lr"] = [clip_l_lr, clip_g_lr, T5_lr] + + #network args + additional_network_args = [] + + if "T5" in train_text_encoder: + additional_network_args.append("train_t5xxl=True") + + if block_args: + additional_network_args.append(block_args["include"]) + + # Handle network_args in args Namespace + if hasattr(args, 'network_args') and isinstance(args.network_args, list): + args.network_args.extend(additional_network_args) + else: + setattr(args, 'network_args', additional_network_args) + + if gradient_checkpointing == "disabled": + config_dict["gradient_checkpointing"] = False + elif gradient_checkpointing == "enabled_with_cpu_offloading": + config_dict["gradient_checkpointing"] = True + config_dict["cpu_offload_checkpointing"] = True + else: + config_dict["gradient_checkpointing"] = True + + if sd3_models["lora_path"]: + config_dict["network_weights"] = sd3_models["lora_path"] + + config_dict.update(kwargs) + config_dict.update(optimizer_settings) + + if loss_args: + config_dict.update(loss_args) + + if resume_args: + config_dict.update(resume_args) + + for key, value in config_dict.items(): + setattr(args, key, value) + + saved_args_file_path = os.path.join(output_dir, f"{output_name}_args.json") + with open(saved_args_file_path, 'w') as f: + json.dump(vars(args), f, indent=4) + + #workflow saving + metadata = {} + if extra_pnginfo is not None: + metadata.update(extra_pnginfo["workflow"]) + + saved_workflow_file_path = os.path.join(output_dir, f"{output_name}_workflow.json") + with open(saved_workflow_file_path, 'w') as f: + json.dump(metadata, f, indent=4) + + #pass args to kohya and initialize trainer + with torch.inference_mode(False): + network_trainer = Sd3NetworkTrainer() + training_loop = network_trainer.init_train(args) + + epochs_count = network_trainer.num_train_epochs + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, epochs_count, args) + + +class SD3TrainLoop: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "steps": ("INT", {"default": 1, "min": 1, "max": 10000, "step": 1, "tooltip": "the step point in training to validate/save"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT",) + RETURN_NAMES = ("network_trainer", "steps",) + FUNCTION = "train" + CATEGORY = "FluxTrainer" + + def train(self, network_trainer, steps): + with torch.inference_mode(False): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + initial_global_step = network_trainer.global_step + + target_global_step = network_trainer.global_step + steps + comfy_pbar = comfy.utils.ProgressBar(steps) + network_trainer.comfy_pbar = comfy_pbar + + network_trainer.optimizer_train_fn() + + while network_trainer.global_step < target_global_step: + steps_done = training_loop( + break_at_steps = target_global_step, + epoch = network_trainer.current_epoch.value, + ) + + # Also break if the global steps have reached the max train steps + if network_trainer.global_step >= network_trainer.args.max_train_steps: + break + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, network_trainer.global_step) + + +class SD3TrainLoRASave: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "save_state": ("BOOLEAN", {"default": False, "tooltip": "save the whole model state as well"}), + "copy_to_comfy_lora_folder": ("BOOLEAN", {"default": False, "tooltip": "copy the lora model to the comfy lora folder"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "STRING", "INT",) + RETURN_NAMES = ("network_trainer","lora_path", "steps",) + FUNCTION = "save" + CATEGORY = "FluxTrainer" + + def save(self, network_trainer, save_state, copy_to_comfy_lora_folder): + import shutil + with torch.inference_mode(False): + trainer = network_trainer["network_trainer"] + global_step = trainer.global_step + + ckpt_name = train_util.get_step_ckpt_name(trainer.args, "." + trainer.args.save_model_as, global_step) + trainer.save_model(ckpt_name, trainer.accelerator.unwrap_model(trainer.network), global_step, trainer.current_epoch.value + 1) + + remove_step_no = train_util.get_remove_step_no(trainer.args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(trainer.args, "." + trainer.args.save_model_as, remove_step_no) + trainer.remove_model(remove_ckpt_name) + + if save_state: + train_util.save_and_remove_state_stepwise(trainer.args, trainer.accelerator, global_step) + + lora_path = os.path.join(trainer.args.output_dir, ckpt_name) + if copy_to_comfy_lora_folder: + destination_dir = os.path.join(folder_paths.models_dir, "loras", "flux_trainer") + os.makedirs(destination_dir, exist_ok=True) + shutil.copy(lora_path, os.path.join(destination_dir, ckpt_name)) + + + return (network_trainer, lora_path, global_step) + + + +class SD3TrainEnd: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "save_state": ("BOOLEAN", {"default": True}), + }, + } + + RETURN_TYPES = ("STRING", "STRING", "STRING",) + RETURN_NAMES = ("lora_name", "metadata", "lora_path",) + FUNCTION = "endtrain" + CATEGORY = "FluxTrainer" + OUTPUT_NODE = True + + def endtrain(self, network_trainer, save_state): + with torch.inference_mode(False): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + + network_trainer.metadata["ss_epoch"] = str(network_trainer.num_train_epochs) + network_trainer.metadata["ss_training_finished_at"] = str(time.time()) + + network = network_trainer.accelerator.unwrap_model(network_trainer.network) + + network_trainer.accelerator.end_training() + network_trainer.optimizer_eval_fn() + + if save_state: + train_util.save_state_on_train_end(network_trainer.args, network_trainer.accelerator) + + ckpt_name = train_util.get_last_ckpt_name(network_trainer.args, "." + network_trainer.args.save_model_as) + network_trainer.save_model(ckpt_name, network, network_trainer.global_step, network_trainer.num_train_epochs, force_sync_upload=True) + logger.info("model saved.") + + final_lora_name = str(network_trainer.args.output_name) + final_lora_path = os.path.join(network_trainer.args.output_dir, ckpt_name) + + # metadata + metadata = json.dumps(network_trainer.metadata, indent=2) + + training_loop = None + network_trainer = None + mm.soft_empty_cache() + + return (final_lora_name, metadata, final_lora_path) + +class SD3TrainValidationSettings: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "steps": ("INT", {"default": 20, "min": 1, "max": 256, "step": 1, "tooltip": "sampling steps"}), + "width": ("INT", {"default": 1024, "min": 64, "max": 4096, "step": 8, "tooltip": "image width"}), + "height": ("INT", {"default": 1024, "min": 64, "max": 4096, "step": 8, "tooltip": "image height"}), + "guidance_scale": ("FLOAT", {"default": 4, "min": 1.0, "max": 32.0, "step": 0.05, "tooltip": "guidance scale"}), + "seed": ("INT", {"default": 42,"min": 0, "max": 0xffffffffffffffff, "step": 1}), + }, + } + + RETURN_TYPES = ("VALSETTINGS", ) + RETURN_NAMES = ("validation_settings", ) + FUNCTION = "set" + CATEGORY = "FluxTrainer" + + def set(self, **kwargs): + validation_settings = kwargs + print(validation_settings) + + return (validation_settings,) + +class SD3TrainValidate: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "network_trainer": ("NETWORKTRAINER",), + }, + "optional": { + "validation_settings": ("VALSETTINGS",), + } + } + + RETURN_TYPES = ("NETWORKTRAINER", "IMAGE",) + RETURN_NAMES = ("network_trainer", "validation_images",) + FUNCTION = "validate" + CATEGORY = "FluxTrainer" + + def validate(self, network_trainer, validation_settings=None): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + + params = ( + network_trainer.current_epoch.value, + network_trainer.global_step, + validation_settings + ) + network_trainer.optimizer_eval_fn() + with torch.inference_mode(False): + image_tensors = network_trainer.sample_images(*params) + + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, (0.5 * (image_tensors + 1.0)).cpu().float(),) + +NODE_CLASS_MAPPINGS = { + "SD3ModelSelect": SD3ModelSelect, + "InitSD3LoRATraining": InitSD3LoRATraining, + "SD3TrainValidationSettings": SD3TrainValidationSettings, +} +NODE_DISPLAY_NAME_MAPPINGS = { + "SD3ModelSelect": "SD3 Model Select", + "InitSD3LoRATraining": "Init SD3 LoRA Training", + "SD3TrainValidationSettings": "SD3 Train Validation Settings", +} diff --git a/nodes_sdxl.py b/nodes_sdxl.py new file mode 100644 index 0000000000000000000000000000000000000000..753fc9f62c9de15f8064deea4271eb47124472ed --- /dev/null +++ b/nodes_sdxl.py @@ -0,0 +1,465 @@ +import os +import torch + +import folder_paths +import comfy.model_management as mm +import comfy.utils +import toml +import json +import time +import shutil +import shlex + +script_directory = os.path.dirname(os.path.abspath(__file__)) + +from .sdxl_train_network import SdxlNetworkTrainer +from .library import sdxl_train_util +from .library.device_utils import init_ipex +init_ipex() + +from .library import train_util +from .train_network import setup_parser as train_network_setup_parser + +import logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +class SDXLModelSelect: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "checkpoint": (folder_paths.get_filename_list("checkpoints"), ), + }, + "optional": { + "lora_path": ("STRING",{"multiline": True, "forceInput": True, "default": "", "tooltip": "pre-trained LoRA path to load (network_weights)"}), + } + } + + RETURN_TYPES = ("TRAIN_SDXL_MODELS",) + RETURN_NAMES = ("sdxl_models",) + FUNCTION = "loadmodel" + CATEGORY = "FluxTrainer/SDXL" + + def loadmodel(self, checkpoint, lora_path=""): + + checkpoint_path = folder_paths.get_full_path("checkpoints", checkpoint) + + SDXL_models = { + "checkpoint": checkpoint_path, + "lora_path": lora_path + } + + return (SDXL_models,) + +class InitSDXLLoRATraining: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "SDXL_models": ("TRAIN_SDXL_MODELS",), + "dataset": ("JSON",), + "optimizer_settings": ("ARGS",), + "output_name": ("STRING", {"default": "SDXL_lora", "multiline": False}), + "output_dir": ("STRING", {"default": "SDXL_trainer_output", "multiline": False, "tooltip": "path to dataset, root is the 'ComfyUI' folder, with windows portable 'ComfyUI_windows_portable'"}), + "network_dim": ("INT", {"default": 16, "min": 1, "max": 100000, "step": 1, "tooltip": "network dim"}), + "network_alpha": ("FLOAT", {"default": 16, "min": 0.0, "max": 2048.0, "step": 0.01, "tooltip": "network alpha"}), + "learning_rate": ("FLOAT", {"default": 1e-6, "min": 0.0, "max": 10.0, "step": 0.0000001, "tooltip": "learning rate"}), + "max_train_steps": ("INT", {"default": 1500, "min": 1, "max": 100000, "step": 1, "tooltip": "max number of training steps"}), + "cache_latents": (["disk", "memory", "disabled"], {"tooltip": "caches text encoder outputs"}), + "cache_text_encoder_outputs": (["disk", "memory", "disabled"], {"tooltip": "caches text encoder outputs"}), + "highvram": ("BOOLEAN", {"default": False, "tooltip": "memory mode"}), + "blocks_to_swap": ("INT", {"default": 0, "min": 0, "max": 100, "step": 1, "tooltip": "option for memory use reduction. The maximum number of blocks that can be swapped is 36 for SDXL.5L and 22 for SDXL.5M"}), + "fp8_base": ("BOOLEAN", {"default": False, "tooltip": "use fp8 for base model"}), + "gradient_dtype": (["fp32", "fp16", "bf16"], {"default": "fp32", "tooltip": "the actual dtype training uses"}), + "save_dtype": (["fp32", "fp16", "bf16", "fp8_e4m3fn", "fp8_e5m2"], {"default": "fp16", "tooltip": "the dtype to save checkpoints as"}), + "attention_mode": (["sdpa", "xformers", "disabled"], {"default": "sdpa", "tooltip": "memory efficient attention mode"}), + "train_text_encoder": (['disabled', 'clip_l',], {"default": 'disabled', "tooltip": "also train the selected text encoders using specified dtype, T5 can not be trained without clip_l"}), + "clip_l_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}), + "clip_g_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}), + "sample_prompts_pos": ("STRING", {"multiline": True, "default": "illustration of a kitten | photograph of a turtle", "tooltip": "validation sample prompts, for multiple prompts, separate by `|`"}), + "sample_prompts_neg": ("STRING", {"multiline": True, "default": "", "tooltip": "validation sample prompts, for multiple prompts, separate by `|`"}), + "gradient_checkpointing": (["enabled", "disabled"], {"default": "enabled", "tooltip": "use gradient checkpointing"}), + }, + "optional": { + "additional_args": ("STRING", {"multiline": True, "default": "", "tooltip": "additional args to pass to the training command"}), + "resume_args": ("ARGS", {"default": "", "tooltip": "resume args to pass to the training command"}), + "block_args": ("ARGS", {"default": "", "tooltip": "limit the blocks used in the LoRA"}), + "loss_args": ("ARGS", {"default": "", "tooltip": "loss args"}), + "network_config": ("NETWORK_CONFIG", {"tooltip": "additional network config"}), + }, + "hidden": { + "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO" + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT", "KOHYA_ARGS",) + RETURN_NAMES = ("network_trainer", "epochs_count", "args",) + FUNCTION = "init_training" + CATEGORY = "FluxTrainer/SDXL" + + def init_training(self, SDXL_models, dataset, optimizer_settings, sample_prompts_pos, sample_prompts_neg, output_name, attention_mode, + gradient_dtype, save_dtype, additional_args=None, resume_args=None, train_text_encoder='disabled', + gradient_checkpointing="enabled", prompt=None, extra_pnginfo=None, clip_l_lr=0, clip_g_lr=0, loss_args=None, network_config=None, **kwargs): + mm.soft_empty_cache() + + output_dir = os.path.abspath(kwargs.get("output_dir")) + os.makedirs(output_dir, exist_ok=True) + + total, used, free = shutil.disk_usage(output_dir) + + required_free_space = 2 * (2**30) + if free <= required_free_space: + raise ValueError(f"Insufficient disk space. Required: {required_free_space/2**30}GB. Available: {free/2**30}GB") + + dataset_config = dataset["datasets"] + dataset_toml = toml.dumps(json.loads(dataset_config)) + + parser = train_network_setup_parser() + #sdxl_train_util.add_sdxl_training_arguments(parser) + if additional_args is not None: + print(f"additional_args: {additional_args}") + args, _ = parser.parse_known_args(args=shlex.split(additional_args)) + else: + args, _ = parser.parse_known_args() + + if kwargs.get("cache_latents") == "memory": + kwargs["cache_latents"] = True + kwargs["cache_latents_to_disk"] = False + elif kwargs.get("cache_latents") == "disk": + kwargs["cache_latents"] = True + kwargs["cache_latents_to_disk"] = True + kwargs["caption_dropout_rate"] = 0.0 + kwargs["shuffle_caption"] = False + kwargs["token_warmup_step"] = 0.0 + kwargs["caption_tag_dropout_rate"] = 0.0 + else: + kwargs["cache_latents"] = False + kwargs["cache_latents_to_disk"] = False + + if kwargs.get("cache_text_encoder_outputs") == "memory": + kwargs["cache_text_encoder_outputs"] = True + kwargs["cache_text_encoder_outputs_to_disk"] = False + elif kwargs.get("cache_text_encoder_outputs") == "disk": + kwargs["cache_text_encoder_outputs"] = True + kwargs["cache_text_encoder_outputs_to_disk"] = True + else: + kwargs["cache_text_encoder_outputs"] = False + kwargs["cache_text_encoder_outputs_to_disk"] = False + + if '|' in sample_prompts_pos: + positive_prompts = sample_prompts_pos.split('|') + else: + positive_prompts = [sample_prompts_pos] + + if '|' in sample_prompts_neg: + negative_prompts = sample_prompts_neg.split('|') + else: + negative_prompts = [sample_prompts_neg] + + config_dict = { + "sample_prompts": positive_prompts, + "negative_prompts": negative_prompts, + "save_precision": save_dtype, + "mixed_precision": "bf16", + "num_cpu_threads_per_process": 1, + "pretrained_model_name_or_path": SDXL_models["checkpoint"], + "save_model_as": "safetensors", + "persistent_data_loader_workers": False, + "max_data_loader_n_workers": 0, + "seed": 42, + "network_module": ".networks.lora" if network_config is None else network_config["network_module"], + "dataset_config": dataset_toml, + "output_name": f"{output_name}_rank{kwargs.get('network_dim')}_{save_dtype}", + "loss_type": "l2", + "alpha_mask": dataset["alpha_mask"], + "network_train_unet_only": True if train_text_encoder == 'disabled' else False, + "disable_mmap_load_safetensors": False, + "network_args": None if network_config is None else network_config["network_args"], + } + attention_settings = { + "sdpa": {"mem_eff_attn": True, "xformers": False, "spda": True}, + "xformers": {"mem_eff_attn": True, "xformers": True, "spda": False} + } + config_dict.update(attention_settings.get(attention_mode, {})) + + gradient_dtype_settings = { + "fp16": {"full_fp16": True, "full_bf16": False, "mixed_precision": "fp16"}, + "bf16": {"full_bf16": True, "full_fp16": False, "mixed_precision": "bf16"} + } + config_dict.update(gradient_dtype_settings.get(gradient_dtype, {})) + + if train_text_encoder != 'disabled': + config_dict["text_encoder_lr"] = [clip_l_lr, clip_g_lr] + + #network args + additional_network_args = [] + + # Handle network_args in args Namespace + if hasattr(args, 'network_args') and isinstance(args.network_args, list): + args.network_args.extend(additional_network_args) + else: + setattr(args, 'network_args', additional_network_args) + + if gradient_checkpointing == "disabled": + config_dict["gradient_checkpointing"] = False + elif gradient_checkpointing == "enabled_with_cpu_offloading": + config_dict["gradient_checkpointing"] = True + config_dict["cpu_offload_checkpointing"] = True + else: + config_dict["gradient_checkpointing"] = True + + if SDXL_models["lora_path"]: + config_dict["network_weights"] = SDXL_models["lora_path"] + + config_dict.update(kwargs) + config_dict.update(optimizer_settings) + + if loss_args: + config_dict.update(loss_args) + + if resume_args: + config_dict.update(resume_args) + + for key, value in config_dict.items(): + setattr(args, key, value) + + saved_args_file_path = os.path.join(output_dir, f"{output_name}_args.json") + with open(saved_args_file_path, 'w') as f: + json.dump(vars(args), f, indent=4) + + #workflow saving + metadata = {} + if extra_pnginfo is not None: + metadata.update(extra_pnginfo["workflow"]) + + saved_workflow_file_path = os.path.join(output_dir, f"{output_name}_workflow.json") + with open(saved_workflow_file_path, 'w') as f: + json.dump(metadata, f, indent=4) + + #pass args to kohya and initialize trainer + with torch.inference_mode(False): + network_trainer = SdxlNetworkTrainer() + training_loop = network_trainer.init_train(args) + + epochs_count = network_trainer.num_train_epochs + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, epochs_count, args) + + +class SDXLTrainLoop: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "steps": ("INT", {"default": 1, "min": 1, "max": 10000, "step": 1, "tooltip": "the step point in training to validate/save"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "INT",) + RETURN_NAMES = ("network_trainer", "steps",) + FUNCTION = "train" + CATEGORY = "FluxTrainer/SDXL" + + def train(self, network_trainer, steps): + with torch.inference_mode(False): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + initial_global_step = network_trainer.global_step + + target_global_step = network_trainer.global_step + steps + comfy_pbar = comfy.utils.ProgressBar(steps) + network_trainer.comfy_pbar = comfy_pbar + + network_trainer.optimizer_train_fn() + + while network_trainer.global_step < target_global_step: + steps_done = training_loop( + break_at_steps = target_global_step, + epoch = network_trainer.current_epoch.value, + ) + + # Also break if the global steps have reached the max train steps + if network_trainer.global_step >= network_trainer.args.max_train_steps: + break + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, network_trainer.global_step) + + +class SDXLTrainLoRASave: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "save_state": ("BOOLEAN", {"default": False, "tooltip": "save the whole model state as well"}), + "copy_to_comfy_lora_folder": ("BOOLEAN", {"default": False, "tooltip": "copy the lora model to the comfy lora folder"}), + }, + } + + RETURN_TYPES = ("NETWORKTRAINER", "STRING", "INT",) + RETURN_NAMES = ("network_trainer","lora_path", "steps",) + FUNCTION = "save" + CATEGORY = "FluxTrainer/SDXL" + + def save(self, network_trainer, save_state, copy_to_comfy_lora_folder): + import shutil + with torch.inference_mode(False): + trainer = network_trainer["network_trainer"] + global_step = trainer.global_step + + ckpt_name = train_util.get_step_ckpt_name(trainer.args, "." + trainer.args.save_model_as, global_step) + trainer.save_model(ckpt_name, trainer.accelerator.unwrap_model(trainer.network), global_step, trainer.current_epoch.value + 1) + + remove_step_no = train_util.get_remove_step_no(trainer.args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(trainer.args, "." + trainer.args.save_model_as, remove_step_no) + trainer.remove_model(remove_ckpt_name) + + if save_state: + train_util.save_and_remove_state_stepwise(trainer.args, trainer.accelerator, global_step) + + lora_path = os.path.join(trainer.args.output_dir, ckpt_name) + if copy_to_comfy_lora_folder: + destination_dir = os.path.join(folder_paths.models_dir, "loras", "flux_trainer") + os.makedirs(destination_dir, exist_ok=True) + shutil.copy(lora_path, os.path.join(destination_dir, ckpt_name)) + + + return (network_trainer, lora_path, global_step) + + + +class SDXLTrainEnd: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "network_trainer": ("NETWORKTRAINER",), + "save_state": ("BOOLEAN", {"default": True}), + }, + } + + RETURN_TYPES = ("STRING", "STRING", "STRING",) + RETURN_NAMES = ("lora_name", "metadata", "lora_path",) + FUNCTION = "endtrain" + CATEGORY = "FluxTrainer/SDXL" + OUTPUT_NODE = True + + def endtrain(self, network_trainer, save_state): + with torch.inference_mode(False): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + + network_trainer.metadata["ss_epoch"] = str(network_trainer.num_train_epochs) + network_trainer.metadata["ss_training_finished_at"] = str(time.time()) + + network = network_trainer.accelerator.unwrap_model(network_trainer.network) + + network_trainer.accelerator.end_training() + network_trainer.optimizer_eval_fn() + + if save_state: + train_util.save_state_on_train_end(network_trainer.args, network_trainer.accelerator) + + ckpt_name = train_util.get_last_ckpt_name(network_trainer.args, "." + network_trainer.args.save_model_as) + network_trainer.save_model(ckpt_name, network, network_trainer.global_step, network_trainer.num_train_epochs, force_sync_upload=True) + logger.info("model saved.") + + final_lora_name = str(network_trainer.args.output_name) + final_lora_path = os.path.join(network_trainer.args.output_dir, ckpt_name) + + # metadata + metadata = json.dumps(network_trainer.metadata, indent=2) + + training_loop = None + network_trainer = None + mm.soft_empty_cache() + + return (final_lora_name, metadata, final_lora_path) + +class SDXLTrainValidationSettings: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "steps": ("INT", {"default": 20, "min": 1, "max": 256, "step": 1, "tooltip": "sampling steps"}), + "width": ("INT", {"default": 1024, "min": 64, "max": 4096, "step": 8, "tooltip": "image width"}), + "height": ("INT", {"default": 1024, "min": 64, "max": 4096, "step": 8, "tooltip": "image height"}), + "guidance_scale": ("FLOAT", {"default": 7.5, "min": 1.0, "max": 32.0, "step": 0.05, "tooltip": "guidance scale"}), + "sampler": (["ddim", "ddpm", "pndm", "lms", "euler", "euler_a", "dpmsolver", "dpmsingle", "heun", "dpm_2", "dpm_2_a",], {"default": "dpm_2", "tooltip": "sampler"}), + "seed": ("INT", {"default": 42,"min": 0, "max": 0xffffffffffffffff, "step": 1}), + }, + } + + RETURN_TYPES = ("VALSETTINGS", ) + RETURN_NAMES = ("validation_settings", ) + FUNCTION = "set" + CATEGORY = "FluxTrainer/SDXL" + + def set(self, **kwargs): + validation_settings = kwargs + print(validation_settings) + + return (validation_settings,) + +class SDXLTrainValidate: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "network_trainer": ("NETWORKTRAINER",), + }, + "optional": { + "validation_settings": ("VALSETTINGS",), + } + } + + RETURN_TYPES = ("NETWORKTRAINER", "IMAGE",) + RETURN_NAMES = ("network_trainer", "validation_images",) + FUNCTION = "validate" + CATEGORY = "FluxTrainer/SDXL" + + def validate(self, network_trainer, validation_settings=None): + training_loop = network_trainer["training_loop"] + network_trainer = network_trainer["network_trainer"] + + params = ( + network_trainer.accelerator, + network_trainer.args, + network_trainer.current_epoch.value, + network_trainer.global_step, + network_trainer.accelerator.device, + network_trainer.vae, + network_trainer.tokenizers, + network_trainer.text_encoder, + network_trainer.unet, + validation_settings, + ) + network_trainer.optimizer_eval_fn() + with torch.inference_mode(False): + image_tensors = network_trainer.sample_images(*params) + + + trainer = { + "network_trainer": network_trainer, + "training_loop": training_loop, + } + return (trainer, (0.5 * (image_tensors + 1.0)).cpu().float(),) + +NODE_CLASS_MAPPINGS = { + "SDXLModelSelect": SDXLModelSelect, + "InitSDXLLoRATraining": InitSDXLLoRATraining, + "SDXLTrainValidationSettings": SDXLTrainValidationSettings, + "SDXLTrainValidate": SDXLTrainValidate, + +} +NODE_DISPLAY_NAME_MAPPINGS = { + "SDXLModelSelect": "SDXL Model Select", + "InitSDXLLoRATraining": "Init SDXL LoRA Training", + "SDXLTrainValidationSettings": "SDXL Train Validation Settings", + "SDXLTrainValidate": "SDXL Train Validate", +} diff --git a/pinokio.js b/pinokio.js new file mode 100644 index 0000000000000000000000000000000000000000..9562448e3955e7f6d3c369a46e9042fac5adb26c --- /dev/null +++ b/pinokio.js @@ -0,0 +1,95 @@ +const path = require('path') +module.exports = { + version: "3.2", + title: "fluxgym", + description: "[NVIDIA Only] Dead simple web UI for training FLUX LoRA with LOW VRAM support (From 12GB)", + icon: "icon.png", + menu: async (kernel, info) => { + let installed = info.exists("env") + let running = { + install: info.running("install.js"), + start: info.running("start.js"), + update: info.running("update.js"), + reset: info.running("reset.js") + } + if (running.install) { + return [{ + default: true, + icon: "fa-solid fa-plug", + text: "Installing", + href: "install.js", + }] + } else if (installed) { + if (running.start) { + let local = info.local("start.js") + if (local && local.url) { + return [{ + default: true, + icon: "fa-solid fa-rocket", + text: "Open Web UI", + href: local.url, + }, { + icon: 'fa-solid fa-terminal', + text: "Terminal", + href: "start.js", + }, { + icon: "fa-solid fa-flask", + text: "Outputs", + href: "outputs?fs" + }] + } else { + return [{ + default: true, + icon: 'fa-solid fa-terminal', + text: "Terminal", + href: "start.js", + }] + } + } else if (running.update) { + return [{ + default: true, + icon: 'fa-solid fa-terminal', + text: "Updating", + href: "update.js", + }] + } else if (running.reset) { + return [{ + default: true, + icon: 'fa-solid fa-terminal', + text: "Resetting", + href: "reset.js", + }] + } else { + return [{ + default: true, + icon: "fa-solid fa-power-off", + text: "Start", + href: "start.js", + }, { + icon: "fa-solid fa-flask", + text: "Outputs", + href: "sd-scripts/fluxgym/outputs?fs" + }, { + icon: "fa-solid fa-plug", + text: "Update", + href: "update.js", + }, { + icon: "fa-solid fa-plug", + text: "Install", + href: "install.js", + }, { + icon: "fa-regular fa-circle-xmark", + text: "Reset", + href: "reset.js", + }] + } + } else { + return [{ + default: true, + icon: "fa-solid fa-plug", + text: "Install", + href: "install.js", + }] + } + } +} diff --git a/pinokio_meta.json b/pinokio_meta.json new file mode 100644 index 0000000000000000000000000000000000000000..c5d4580458be87c97c311020241b5c49ef93c870 --- /dev/null +++ b/pinokio_meta.json @@ -0,0 +1,39 @@ +{ + "posts": [ + "https://x.com/cocktailpeanut/status/1851721405408166064", + "https://x.com/cocktailpeanut/status/1835719701172756592", + "https://x.com/LikeToasters/status/1834258975384092858", + "https://x.com/cocktailpeanut/status/1834245329627009295", + "https://x.com/jkch0205/status/1834003420132614450", + "https://x.com/huwhitememes/status/1834074992209699132", + "https://x.com/GorillaRogueGam/status/1834148656791888139", + "https://x.com/cocktailpeanut/status/1833964839519068303", + "https://x.com/cocktailpeanut/status/1833935061907079521", + "https://x.com/cocktailpeanut/status/1833940728881242135", + "https://x.com/cocktailpeanut/status/1833881392482066638", + "https://x.com/Alone1Moon/status/1833348850662445369", + "https://x.com/_f_ai_9/status/1833485349995397167", + "https://x.com/intocryptoast/status/1833061082862412186", + "https://x.com/cocktailpeanut/status/1833888423716827321", + "https://x.com/cocktailpeanut/status/1833884852992516596", + "https://x.com/cocktailpeanut/status/1833885335077417046", + "https://x.com/NiwonArt/status/1833565746624139650", + "https://x.com/cocktailpeanut/status/1833884361986380117", + "https://x.com/NiwonArt/status/1833599399764889685", + "https://x.com/LikeToasters/status/1832934391217045913", + "https://x.com/cocktailpeanut/status/1832924887456817415", + "https://x.com/cocktailpeanut/status/1832927154536902897", + "https://x.com/YabaiHamster/status/1832697724690386992", + "https://x.com/cocktailpeanut/status/1832747889497366706", + "https://x.com/PhotogenicWeekE/status/1832720544959185202", + "https://x.com/zuzaritt/status/1832748542164652390", + "https://x.com/foxyy4i/status/1832764883710185880", + "https://x.com/waynedahlberg/status/1832226132999213095", + "https://x.com/PhotoGarrido/status/1832214644515041770", + "https://x.com/cocktailpeanut/status/1832787205774786710", + "https://x.com/cocktailpeanut/status/1832151307198541961", + "https://x.com/cocktailpeanut/status/1832145996014612735", + "https://x.com/cocktailpeanut/status/1832084951115972653", + "https://x.com/cocktailpeanut/status/1832091112086843684" + ] +} diff --git a/publish_to_hf.png b/publish_to_hf.png new file mode 100644 index 0000000000000000000000000000000000000000..25c572624bcbbcb5f0a5d626ce99daaf05da40ad --- /dev/null +++ b/publish_to_hf.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cac2aa25db8911b38ed7e084bbbafb226252e26935dbb107ee66b8cc626a95e6 +size 418251 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..291426c178063a47091e4d82760b7f702ebeb9bd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,43 @@ +safetensors +git+https://github.com/huggingface/diffusers.git +gradio_logsview@https://huggingface.co/spaces/cocktailpeanut/gradio_logsview/resolve/main/gradio_logsview-0.0.17-py3-none-any.whl +transformers +lycoris-lora==1.8.3 +flatten_json +pyyaml +oyaml +tensorboard +kornia +invisible-watermark +einops +accelerate +toml +albumentations +pydantic +omegaconf +k-diffusion +open_clip_torch +timm +prodigyopt +controlnet_aux==0.0.7 +python-dotenv +bitsandbytes +hf_transfer +lpips +pytorch_fid +optimum-quanto +sentencepiece +huggingface_hub +peft +gradio +python-slugify +imagesize +pydantic==2.9.2 +slugify +easygui +argparse +pytorch-lightning==1.9.0 +# triton +comfy +gradio +voluptuous diff --git a/reset.js b/reset.js new file mode 100644 index 0000000000000000000000000000000000000000..0278948c73f96521352a9358dc55b62010efae3e --- /dev/null +++ b/reset.js @@ -0,0 +1,13 @@ +module.exports = { + run: [{ + method: "fs.rm", + params: { + path: "sd-scripts" + } + }, { + method: "fs.rm", + params: { + path: "env" + } + }] +} diff --git a/sample.png b/sample.png new file mode 100644 index 0000000000000000000000000000000000000000..a7f29b47eabfec81b53f82b6cc336bbd2863c468 --- /dev/null +++ b/sample.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a1670e3ce2a35d0cffec798ea04f4216b7d4d766e1e785ef23e94f6d2d22ff1 +size 1293751 diff --git a/sample_fields.png b/sample_fields.png new file mode 100644 index 0000000000000000000000000000000000000000..bc50bee365e03c2e25779299dc854712838f61c3 Binary files /dev/null and b/sample_fields.png differ diff --git a/screenshot.png b/screenshot.png new file mode 100644 index 0000000000000000000000000000000000000000..4a4fd700582986ab500e118e5764b04bc4c14f85 --- /dev/null +++ b/screenshot.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cde964e4a233bf3ad7219ac91058f805ae0cbc0f853b7f6aa552af5b6f8c5c8a +size 242712 diff --git a/sd3_train_network.py b/sd3_train_network.py new file mode 100644 index 0000000000000000000000000000000000000000..9c4273a5ac74b82cc8c9b60f5285012d6e87b62f --- /dev/null +++ b/sd3_train_network.py @@ -0,0 +1,486 @@ +import argparse +import math + +from typing import Any, Optional + +import torch +from accelerate import Accelerator +from .library import sd3_models, strategy_sd3, utils +from .library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from .library import flux_models, flux_utils, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3, train_util +from . import train_network +from .library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class Sd3NetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.sample_prompts_te_outputs = None + + def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup): + # super().assert_extra_args(args, train_dataset_group) + # sdxl_train_util.verify_sdxl_training_args(args) + + if args.fp8_base_unet: + args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for SD3 + + if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: + logger.warning( + "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります" + ) + args.cache_text_encoder_outputs = True + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # prepare CLIP-L/CLIP-G/T5XXL training flags + self.train_clip = not args.network_train_unet_only + self.train_t5xxl = False # default is False even if args.network_train_unet_only is False + + if args.max_token_length is not None: + logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") + + assert ( + args.blocks_to_swap is None or args.blocks_to_swap == 0 + ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" + + train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + + # enumerate resolutions from dataset for positional embeddings + self.resolutions = train_dataset_group.get_resolutions() + + def load_target_model(self, args, weight_dtype, accelerator): + # currently offload to cpu for some models + + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) + loading_dtype = None if args.fp8_base else weight_dtype + + # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future + state_dict = utils.load_safetensors( + args.pretrained_model_name_or_path, "cpu", disable_mmap=args.disable_mmap_load_safetensors, dtype=loading_dtype + ) + mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu") + self.model_type = mmdit.model_type + mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate) + + # set resolutions for positional embeddings + if args.enable_scaled_pos_embed: + latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent + latent_sizes = list(set(latent_sizes)) # remove duplicates + logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}") + mmdit.enable_scaled_pos_embed(True, latent_sizes) + + if args.fp8_base: + # check dtype of model + if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {mmdit.dtype}") + elif mmdit.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 SD3 model") + else: + logger.info( + "Cast SD3 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint." + ) + mmdit.to(torch.float8_e4m3fn) + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + if self.is_swapping_blocks: + # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. + logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") + mmdit.enable_block_swap(args.blocks_to_swap, accelerator.device) + + clip_l = sd3_utils.load_clip_l( + args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + clip_l.eval() + clip_g = sd3_utils.load_clip_g( + args.clip_g, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + clip_g.eval() + + # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) + if args.fp8_base and not args.fp8_base_unet: + loading_dtype = None # as is + else: + loading_dtype = weight_dtype + + # loading t5xxl to cpu takes a long time, so we should load to gpu in future + t5xxl = sd3_utils.load_t5xxl( + args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + t5xxl.eval() + if args.fp8_base and not args.fp8_base_unet: + # check dtype of model + if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz: + raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}") + elif t5xxl.dtype == torch.float8_e4m3fn: + logger.info("Loaded fp8 T5XXL model") + + vae = sd3_utils.load_vae( + args.vae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict + ) + + return mmdit.model_type, [clip_l, clip_g, t5xxl], vae, mmdit + + def get_tokenize_strategy(self, args): + logger.info(f"t5xxl_max_token_length: {args.t5xxl_max_token_length}") + return strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy): + return [tokenize_strategy.clip_l, tokenize_strategy.clip_g, tokenize_strategy.t5xxl] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sd3.Sd3TextEncodingStrategy( + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + args.clip_l_dropout_rate, + args.clip_g_dropout_rate, + args.t5_dropout_rate, + ) + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + # check t5xxl is trained or not + self.train_t5xxl = network.train_t5xxl + + if self.train_t5xxl and args.cache_text_encoder_outputs: + raise ValueError( + "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません" + ) + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + if args.cache_text_encoder_outputs: + if self.train_clip and not self.train_t5xxl: + return text_encoders[0:2] + [None] # only CLIP-L/CLIP-G is needed for encoding because T5XXL is cached + else: + return None # no text encoders are needed for encoding because both are cached + else: + return text_encoders # CLIP-L, CLIP-G and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [self.train_clip, self.train_clip, self.train_t5xxl] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + # if the text encoders is trained, we need tokenization, so is_partial is True + return strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, + args.text_encoder_batch_size, + args.skip_cache_check, + is_partial=self.train_clip or self.train_t5xxl, + apply_lg_attn_mask=args.apply_lg_attn_mask, + apply_t5_attn_mask=args.apply_t5_attn_mask, + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + logger.info("move text encoders to gpu") + text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[1].to(accelerator.device, dtype=weight_dtype) # always not fp8 + text_encoders[2].to(accelerator.device) # may be fp8 + + if text_encoders[2].dtype == torch.float8_e4m3fn: + # if we load fp8 weights, the model is already fp8, so we use it as is + self.prepare_text_encoder_fp8(2, text_encoders[2], text_encoders[2].dtype, weight_dtype) + else: + # otherwise, we need to convert it to target dtype + text_encoders[2].to(weight_dtype) + + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) + + # cache sample prompts + if args.sample_prompts is not None: + logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}") + + tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy() + text_encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + + prompts = [] + for line in args.sample_prompts: + line = line.strip() + if len(line) > 0 and line[0] != "#": + prompts.append(line) + + # preprocess prompts + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + from .library.train_util import line_to_prompt_dict + + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + + sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs + with accelerator.autocast(), torch.no_grad(): + for prompt_dict in prompts: + for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]: + if p not in sample_prompts_te_outputs: + logger.info(f"cache Text Encoder outputs for prompt: {p}") + tokens_and_masks = tokenize_strategy.tokenize(p) + sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( + tokenize_strategy, + text_encoders, + tokens_and_masks, + args.apply_lg_attn_mask, + args.apply_t5_attn_mask, + ) + self.sample_prompts_te_outputs = sample_prompts_te_outputs + + accelerator.wait_for_everyone() + + # move back to cpu + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move CLIP-G back to cpu") + text_encoders[1].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[2].to("cpu") + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) + text_encoders[2].to(accelerator.device) + + # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # # get size embeddings + # orig_size = batch["original_sizes_hw"] + # crop_size = batch["crop_top_lefts"] + # target_size = batch["target_sizes_hw"] + # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # # concat embeddings + # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + # return noise_pred + + def sample_images(self, epoch, global_step, validation_settings): + text_encoders = self.get_models_for_text_encoding(self.args, self.accelerator, self.text_encoder) + image_tensors = sd3_train_utils.sample_images( + self.accelerator, self.args, epoch, global_step, self.unet, self.vae, text_encoders, self.sample_prompts_te_outputs, validation_settings + ) + + return image_tensors.permute(0, 2, 3, 1) + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + # this scheduler is not used in training, but used to get num_train_timesteps etc. + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images) + + def shift_scale_latents(self, args, latents): + return sd3_models.SDVAE.process_in(latents) + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet: flux_models.Flux, + network, + weight_dtype, + train_unet, + ): + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps( + args, latents, noise, accelerator.device, weight_dtype + ) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + noisy_model_input.requires_grad_(True) + for t in text_encoder_conds: + if t is not None and t.dtype.is_floating_point: + t.requires_grad_(True) + + # Predict the noise residual + lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds + text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() + context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled) + if not args.apply_lg_attn_mask: + l_attn_mask = None + g_attn_mask = None + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + # call model + with accelerator.autocast(): + # TODO support attention mask + model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled) + + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + # Preconditioning of the model outputs. + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = latents + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + model_pred_prior = unet( + noisy_model_input[diff_output_pr_indices], + timesteps[diff_output_pr_indices], + context=context[diff_output_pr_indices], + y=lg_pooled[diff_output_pr_indices], + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + + model_pred_prior = model_pred_prior * (-sigmas[diff_output_pr_indices]) + noisy_model_input[diff_output_pr_indices] + + # weighting for differential output preservation is not needed because it is already applied + + target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) + + return model_pred, target, timesteps, weighting + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, False, True, False, sd3=self.model_type) + + def update_metadata(self, metadata, args): + metadata["ss_apply_lg_attn_mask"] = args.apply_lg_attn_mask + metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask + metadata["ss_weighting_scheme"] = args.weighting_scheme + metadata["ss_logit_mean"] = args.logit_mean + metadata["ss_logit_std"] = args.logit_std + metadata["ss_mode_scale"] = args.mode_scale + + def is_text_encoder_not_needed_for_training(self, args): + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + if index == 0 or index == 1: # CLIP-L/CLIP-G + return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder) + else: # T5XXL + text_encoder.encoder.embed_tokens.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + if index == 0 or index == 1: # CLIP-L/CLIP-G + clip_type = "CLIP-L" if index == 0 else "CLIP-G" + logger.info(f"prepare CLIP-{clip_type} for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}") + text_encoder.to(te_weight_dtype) # fp8 + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + else: # T5XXL + + def prepare_fp8(text_encoder, target_dtype): + def forward_hook(module): + def forward(hidden_states): + hidden_gelu = module.act(module.wi_0(hidden_states)) + hidden_linear = module.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = module.dropout(hidden_states) + + hidden_states = module.wo(hidden_states) + return hidden_states + + return forward + + for module in text_encoder.modules(): + if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["T5DenseGatedActDense"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = forward_hook(module) + + if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: + logger.info(f"T5XXL already prepared for fp8") + else: + logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") + text_encoder.to(te_weight_dtype) # fp8 + prepare_fp8(text_encoder, weight_dtype) + + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + # drop cached text encoder outputs + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() + text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) + batch["text_encoder_outputs_list"] = text_encoder_outputs_list + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + if not self.is_swapping_blocks: + return super().prepare_unet_with_accelerator(args, accelerator, unet) + + # if we doesn't swap blocks, we can move the model to device + mmdit: sd3_models.MMDiT = unet + mmdit = accelerator.prepare(mmdit, device_placement=[not self.is_swapping_blocks]) + accelerator.unwrap_model(mmdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage + accelerator.unwrap_model(mmdit).prepare_block_swap_before_forward() + + return mmdit + + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + train_util.add_dit_training_arguments(parser) + sd3_train_utils.add_sd3_training_arguments(parser) + return parser diff --git a/sdxl_train_network.py b/sdxl_train_network.py new file mode 100644 index 0000000000000000000000000000000000000000..3fc3c3ac3db947a17e1c6b7fc95f017c6f9402c1 --- /dev/null +++ b/sdxl_train_network.py @@ -0,0 +1,228 @@ +import argparse +from typing import List, Optional + +import torch +from accelerate import Accelerator +from .library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from .library import sdxl_model_util, sdxl_train_util, strategy_base, strategy_sd, strategy_sdxl, train_util +from . import train_network +from .library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class SdxlNetworkTrainer(train_network.NetworkTrainer): + def __init__(self): + super().__init__() + self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR + self.is_sdxl = True + + def assert_extra_args(self, args, train_dataset_group): + super().assert_extra_args(args, train_dataset_group) + sdxl_train_util.verify_sdxl_training_args(args) + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + assert ( + args.network_train_unet_only or not args.cache_text_encoder_outputs + ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + + train_dataset_group.verify_bucket_reso_steps(32) + + def load_target_model(self, args, weight_dtype, accelerator): + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) + + self.load_stable_diffusion_format = load_stable_diffusion_format + self.logit_scale = logit_scale + self.ckpt_info = ckpt_info + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + vae.set_use_memory_efficient_attention_xformers(args.xformers) + + return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet + + def get_tokenize_strategy(self, args): + return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy): + return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sdxl.SdxlTextEncodingStrategy() + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + return text_encoders + [accelerator.unwrap_model(text_encoders[-1])] + + def get_text_encoder_outputs_caching_strategy(self, args): + if args.cache_text_encoder_outputs: + return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions + ) + else: + return None + + def cache_text_encoder_outputs_if_needed( + self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype + ): + if args.cache_text_encoder_outputs: + if not args.lowram: + # メモリ消費を減らす + logger.info("move vae and unet to cpu to save memory") + org_vae_device = vae.device + org_unet_device = unet.device + vae.to("cpu") + unet.to("cpu") + clean_memory_on_device(accelerator.device) + + # When TE is not be trained, it will not be prepared so we need to use explicit autocast + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) + with accelerator.autocast(): + dataset.new_cache_text_encoder_outputs(text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator) + accelerator.wait_for_everyone() + + text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU + text_encoders[1].to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + + if not args.lowram: + logger.info("move vae and unet back to original device") + vae.to(org_vae_device) + unet.to(org_unet_device) + else: + # Text Encoderから毎回出力を取得するので、GPUに乗せておく + text_encoders[0].to(accelerator.device, dtype=weight_dtype) + text_encoders[1].to(accelerator.device, dtype=weight_dtype) + + def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): + if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: + input_ids1 = batch["input_ids"] + input_ids2 = batch["input_ids2"] + with torch.enable_grad(): + # Get the text embedding for conditioning + # TODO support weighted captions + # if args.weighted_captions: + # encoder_hidden_states = get_weighted_text_embeddings( + # tokenizer, + # text_encoder, + # batch["captions"], + # accelerator.device, + # args.max_token_length // 75 if args.max_token_length else 1, + # clip_skip=args.clip_skip, + # ) + # else: + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( + args.max_token_length, + input_ids1, + input_ids2, + tokenizers[0], + tokenizers[1], + text_encoders[0], + text_encoders[1], + None if not args.full_fp16 else weight_dtype, + accelerator=accelerator, + ) + else: + encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) + encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) + pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + + # # verify that the text encoder outputs are correct + # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl( + # args.max_token_length, + # batch["input_ids"].to(text_encoders[0].device), + # batch["input_ids2"].to(text_encoders[0].device), + # tokenizers[0], + # tokenizers[1], + # text_encoders[0], + # text_encoders[1], + # None if not args.full_fp16 else weight_dtype, + # ) + # b_size = encoder_hidden_states1.shape[0] + # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 + # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 + # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 + # logger.info("text encoder outputs verified") + + return encoder_hidden_states1, encoder_hidden_states2, pool2 + + def call_unet( + self, + args, + accelerator, + unet, + noisy_latents, + timesteps, + text_conds, + batch, + weight_dtype, + indices: Optional[List[int]] = None, + ): + noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # concat embeddings + encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds + vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + if indices is not None and len(indices) > 0: + noisy_latents = noisy_latents[indices] + timesteps = timesteps[indices] + text_embedding = text_embedding[indices] + vector_embedding = vector_embedding[indices] + + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + return noise_pred + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, validation_settings=None): + image_tensors = sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, validation_settings) + return image_tensors + +def setup_parser() -> argparse.ArgumentParser: + parser = train_network.setup_parser() + sdxl_train_util.add_sdxl_training_arguments(parser) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = SdxlNetworkTrainer() + trainer.train(args) diff --git a/seed.gif b/seed.gif new file mode 100644 index 0000000000000000000000000000000000000000..c1d209a98db2408230a2d2c67f749fb0bee8793b --- /dev/null +++ b/seed.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:271dbf11ef0c709558bb570c4c2b7765001356eefcbcc9cf0f0713262a91937f +size 3622293 diff --git a/start.js b/start.js new file mode 100644 index 0000000000000000000000000000000000000000..494cb260c56aa96749753b67790aa4af4530e34d --- /dev/null +++ b/start.js @@ -0,0 +1,37 @@ +module.exports = { + daemon: true, + run: [ + { + method: "shell.run", + params: { + venv: "env", // Edit this to customize the venv folder path + env: { + LOG_LEVEL: "DEBUG", + CUDA_VISIBLE_DEVICES: "0" + }, // Edit this to customize environment variables (see documentation) + message: [ + "python app.py", // Edit with your custom commands + ], + on: [{ + // The regular expression pattern to monitor. + // When this pattern occurs in the shell terminal, the shell will return, + // and the script will go onto the next step. + "event": "/http:\/\/\\S+/", + + // "done": true will move to the next step while keeping the shell alive. + // "kill": true will move to the next step after killing the shell. + "done": true + }] + } + }, + { + // This step sets the local variable 'url'. + // This local variable will be used in pinokio.js to display the "Open WebUI" tab when the value is set. + method: "local.set", + params: { + // the input.event is the regular expression match object from the previous step + url: "{{input.event[0]}}" + } + } + ] +} diff --git a/torch.js b/torch.js new file mode 100644 index 0000000000000000000000000000000000000000..4d6e4d1ad3377b2a7300e0da9e1883399743fd70 --- /dev/null +++ b/torch.js @@ -0,0 +1,75 @@ +module.exports = { + run: [ + // windows nvidia + { + "when": "{{platform === 'win32' && gpu === 'nvidia'}}", + "method": "shell.run", + "params": { + "venv": "{{args && args.venv ? args.venv : null}}", + "path": "{{args && args.path ? args.path : '.'}}", + "message": "uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 --force-reinstall" + + } + }, + // windows amd + { + "when": "{{platform === 'win32' && gpu === 'amd'}}", + "method": "shell.run", + "params": { + "venv": "{{args && args.venv ? args.venv : null}}", + "path": "{{args && args.path ? args.path : '.'}}", + "message": "uv pip install torch-directml torchaudio torchvision" + } + }, + // windows cpu + { + "when": "{{platform === 'win32' && (gpu !== 'nvidia' && gpu !== 'amd')}}", + "method": "shell.run", + "params": { + "venv": "{{args && args.venv ? args.venv : null}}", + "path": "{{args && args.path ? args.path : '.'}}", + "message": "uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --force-reinstall" + } + }, + // mac + { + "when": "{{platform === 'darwin'}}", + "method": "shell.run", + "params": { + "venv": "{{args && args.venv ? args.venv : null}}", + "path": "{{args && args.path ? args.path : '.'}}", + "message": "uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --force-reinstall" + } + }, + // linux nvidia + { + "when": "{{platform === 'linux' && gpu === 'nvidia'}}", + "method": "shell.run", + "params": { + "venv": "{{args && args.venv ? args.venv : null}}", + "path": "{{args && args.path ? args.path : '.'}}", + "message": "uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 --force-reinstall" + } + }, + // linux rocm (amd) + { + "when": "{{platform === 'linux' && gpu === 'amd'}}", + "method": "shell.run", + "params": { + "venv": "{{args && args.venv ? args.venv : null}}", + "path": "{{args && args.path ? args.path : '.'}}", + "message": "uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1 --force-reinstall" + } + }, + // linux cpu + { + "when": "{{platform === 'linux' && (gpu !== 'amd' && gpu !=='nvidia')}}", + "method": "shell.run", + "params": { + "venv": "{{args && args.venv ? args.venv : null}}", + "path": "{{args && args.path ? args.path : '.'}}", + "message": "uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --force-reinstall" + } + } + ] +} diff --git a/train_db.py b/train_db.py new file mode 100644 index 0000000000000000000000000000000000000000..3b8d7b205284c09cd867afda259f8028f589958d --- /dev/null +++ b/train_db.py @@ -0,0 +1,558 @@ +# DreamBooth training +# XXX dropped option: fine_tune + +import argparse +import itertools +import math +import os +from multiprocessing import Value +import toml + +from tqdm import tqdm + +import torch +from library import deepspeed_utils, strategy_base +from library.device_utils import init_ipex, clean_memory_on_device + + +init_ipex() + +from accelerate.utils import set_seed +from diffusers import DDPMScheduler + +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + apply_snr_weight, + get_weighted_text_embeddings, + prepare_scheduler_for_custom_training, + pyramid_noise_like, + apply_noise_offset, + scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, + apply_masked_loss, +) +from .utils import setup_logging, add_logging_arguments +import library.strategy_sd as strategy_sd + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +# perlin_noise, + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, False) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + cache_latents = args.cache_latents + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + False, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, args.masked_loss, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + if args.no_token_padding: + train_dataset_group.disable_token_padding() + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + + if args.gradient_accumulation_steps > 1: + logger.warning( + f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong" + ) + logger.warning( + f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です" + ) + + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator) + + # verify load/save model formats + if load_stable_diffusion_format: + src_stable_diffusion_ckpt = args.pretrained_model_name_or_path + src_diffusers_model_path = None + else: + src_stable_diffusion_ckpt = None + src_diffusers_model_path = args.pretrained_model_name_or_path + + if args.save_model_as is None: + save_stable_diffusion_format = load_stable_diffusion_format + use_safetensors = args.use_safetensors + else: + save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" + use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + + train_dataset_group.new_cache_latents(vae, accelerator) + + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + # 学習を準備する:モデルを適切な状態にする + train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 + unet.requires_grad_(True) # 念のため追加 + text_encoder.requires_grad_(train_text_encoder) + if not train_text_encoder: + accelerator.print("Text Encoder is not trained.") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + if train_text_encoder: + if args.learning_rate_te is None: + # wightout list, adamw8bit is crashed + trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters())) + else: + trainable_params = [ + {"params": list(unet.parameters()), "lr": args.learning_rate}, + {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te}, + ] + else: + trainable_params = unet.parameters() + + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + if args.stop_text_encoder_training is None: + args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end + + # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + unet.to(weight_dtype) + text_encoder.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + if args.deepspeed: + if args.train_text_encoder: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder) + else: + ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet) + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_models = [ds_model] + + else: + if train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + training_models = [unet, text_encoder] + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + training_models = [unet] + + if not train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + accelerator.print( + f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + ) + accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "dreambooth" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + # For --sample_at_first + train_util.sample_images( + accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) + + loss_recorder = train_util.LossRecorder() + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + # 指定したステップ数までText Encoderを学習する:epoch最初の状態 + unet.train() + # train==True is required to enable gradient_checkpointing + if args.gradient_checkpointing or global_step < args.stop_text_encoder_training: + text_encoder.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + # 指定したステップ数でText Encoderの学習を止める + if global_step == args.stop_text_encoder_training: + accelerator.print(f"stop text encoder training at step {global_step}") + if not args.gradient_checkpointing: + text_encoder.train(False) + text_encoder.requires_grad_(False) + if len(training_models) == 2: + training_models = training_models[0] # remove text_encoder from training_models + + with accelerator.accumulate(*training_models): + with torch.no_grad(): + # latentに変換 + if cache_latents: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + # Get the text embedding for conditioning + with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): + if args.weighted_captions: + encoder_hidden_states = get_weighted_text_embeddings( + tokenize_strategy.tokenizer, + text_encoder, + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + input_ids = batch["input_ids_list"][0].to(accelerator.device) + encoder_hidden_states = text_encoding_strategy.encode_tokens( + tokenize_strategy, [text_encoder], [input_ids] + )[0] + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + if train_text_encoder: + params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters()) + else: + params_to_clip = unet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end_or_stepwise( + args, + False, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(text_encoder), + accelerator.unwrap_model(unet), + vae, + ) + + current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss} + train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True) + accelerator.log(logs, step=global_step) + + loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + avr_loss: float = loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + if accelerator.is_main_process: + # checking for saving is in util + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end_or_stepwise( + args, + True, + accelerator, + src_path, + save_stable_diffusion_format, + use_safetensors, + save_dtype, + epoch, + num_train_epochs, + global_step, + accelerator.unwrap_model(text_encoder), + accelerator.unwrap_model(unet), + vae, + ) + + train_util.sample_images( + accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet + ) + + is_main_process = accelerator.is_main_process + if is_main_process: + unet = accelerator.unwrap_model(unet) + text_encoder = accelerator.unwrap_model(text_encoder) + + accelerator.end_training() + + if is_main_process and (args.save_state or args.save_state_on_train_end): + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_train_end( + args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae + ) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, False, True) + train_util.add_training_arguments(parser, True) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_sd_saving_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument( + "--learning_rate_te", + type=float, + default=None, + help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ", + ) + parser.add_argument( + "--no_token_padding", + action="store_true", + help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)", + ) + parser.add_argument( + "--stop_text_encoder_training", + type=int, + default=None, + help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない", + ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/train_network.py b/train_network.py new file mode 100644 index 0000000000000000000000000000000000000000..6f8a3476f1a1ae253d9b8d155dfb147c7c82459c --- /dev/null +++ b/train_network.py @@ -0,0 +1,1500 @@ +import importlib +import argparse +import math +import os +import sys +import random +import time +import json +from multiprocessing import Value +from typing import Any, List +import toml + +from tqdm import tqdm +# from comfy.utils import ProgressBar + +import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from accelerate import Accelerator +from diffusers import DDPMScheduler +from library import deepspeed_utils, model_util, strategy_base, strategy_sd + +from library import train_util as train_util +from library.train_util import DreamBoothDataset +from library import config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library import huggingface_util as huggingface_util +from library import custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + apply_snr_weight, + get_weighted_text_embeddings, + prepare_scheduler_for_custom_training, + scale_v_prediction_loss_like_noise_prediction, + add_v_prediction_like_loss, + apply_debiased_estimation, + apply_masked_loss, +) +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class NetworkTrainer: + def __init__(self): + self.vae_scale_factor = 0.18215 + self.is_sdxl = False + + # TODO 他のスクリプトと共通化する + def generate_step_logs( + self, + args: argparse.Namespace, + current_loss, + avr_loss, + lr_scheduler, + lr_descriptions, + optimizer=None, + keys_scaled=None, + mean_norm=None, + maximum_norm=None, + ): + logs = {"loss/current": current_loss, "loss/average": avr_loss} + + if keys_scaled is not None: + logs["max_norm/keys_scaled"] = keys_scaled + logs["max_norm/average_key_norm"] = mean_norm + logs["max_norm/max_key_norm"] = maximum_norm + + lrs = lr_scheduler.get_last_lr() + for i, lr in enumerate(lrs): + if lr_descriptions is not None: + lr_desc = lr_descriptions[i] + else: + idx = i - (0 if args.network_train_unet_only else -1) + if idx == -1: + lr_desc = "textencoder" + else: + if len(lrs) > 2: + lr_desc = f"group{idx}" + else: + lr_desc = "unet" + + logs[f"lr/{lr_desc}"] = lr + + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): + # tracking d*lr value + logs[f"lr/d*lr/{lr_desc}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] + ) + if ( + args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None + ): # tracking d*lr value of unet. + logs["lr/d*lr"] = ( + optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] + ) + else: + idx = 0 + if not args.network_train_unet_only: + logs["lr/textencoder"] = float(lrs[0]) + idx = 1 + + for i in range(idx, len(lrs)): + logs[f"lr/group{i}"] = float(lrs[i]) + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): + logs[f"lr/d*lr/group{i}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] + ) + if ( + args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None + ): + logs[f"lr/d*lr/group{i}"] = ( + optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] + ) + + return logs + + def assert_extra_args(self, args, train_dataset_group): + train_dataset_group.verify_bucket_reso_steps(64) + + def load_target_model(self, args, weight_dtype, accelerator): + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) + + # Incorporate xformers or memory efficient attention into the model + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + if torch.__version__ >= "2.0.0": # If you have xformers compatible with PyTorch 2.0.0 or higher, you can use the following + vae.set_use_memory_efficient_attention_xformers(args.xformers) + + return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet + + def get_tokenize_strategy(self, args): + return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + + def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]: + return [tokenize_strategy.tokenizer] + + def get_latents_caching_strategy(self, args): + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check + ) + return latents_caching_strategy + + def get_text_encoding_strategy(self, args): + return strategy_sd.SdTextEncodingStrategy(args.clip_skip) + + def get_text_encoder_outputs_caching_strategy(self, args): + return None + + def get_models_for_text_encoding(self, args, accelerator, text_encoders): + """ + Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models. + FLUX.1 and SD3 may cache some outputs of the text encoder, so return the models that will be used for encoding (not cached). + """ + return text_encoders + + # returns a list of bool values indicating whether each text encoder should be trained + def get_text_encoders_train_flags(self, args, text_encoders): + return [True] * len(text_encoders) if self.is_train_text_encoder(args) else [False] * len(text_encoders) + + def is_train_text_encoder(self, args): + return not args.network_train_unet_only + + def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype): + for t_enc in text_encoders: + t_enc.to(accelerator.device, dtype=weight_dtype) + + def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs): + noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample + return noise_pred + + def all_reduce_network(self, accelerator, network): + for param in network.parameters(): + if param.grad is not None: + param.grad = accelerator.reduce(param.grad, reduction="mean") + + def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet): + train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet) + + # region SD/SDXL + + def post_process_network(self, args, accelerator, network, text_encoders, unet): + pass + + def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + return noise_scheduler + + def encode_images_to_latents(self, args, accelerator, vae, images): + return vae.encode(images).latent_dist.sample() + + def shift_scale_latents(self, args, latents): + return latents * self.vae_scale_factor + + def get_noise_pred_and_target( + self, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + ): + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + + # ensure the hidden state will require grad + if args.gradient_checkpointing: + for x in noisy_latents: + x.requires_grad_(True) + for t in text_encoder_conds: + t.requires_grad_(True) + + # Predict the noise residual + with accelerator.autocast(): + noise_pred = self.call_unet( + args, + accelerator, + unet, + noisy_latents.requires_grad_(train_unet), + timesteps, + text_encoder_conds, + batch, + weight_dtype, + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + noise_pred_prior = self.call_unet( + args, + accelerator, + unet, + noisy_latents, + timesteps, + text_encoder_conds, + batch, + weight_dtype, + indices=diff_output_pr_indices, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) + + return noise_pred, target, timesteps, None + + def post_process_loss(self, loss, args, timesteps, noise_scheduler): + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) + return loss + + def get_sai_model_spec(self, args): + return train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False) + + def update_metadata(self, metadata, args): + pass + + def is_text_encoder_not_needed_for_training(self, args): + return False # use for sample images + + def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder): + # set top parameter requires_grad = True for gradient checkpointing works + text_encoder.text_model.embeddings.requires_grad_(True) + + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): + text_encoder.text_model.embeddings.to(dtype=weight_dtype) + + def prepare_unet_with_accelerator( + self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module + ) -> torch.nn.Module: + return accelerator.prepare(unet) + + def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): + pass + + # endregion + + def init_train(self, args): + session_id = random.randint(0, 2**32) + training_started_at = time.time() + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + deepspeed_utils.prepare_deepspeed_args(args) + setup_logging(args, reset=True) + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + use_user_config = args.dataset_config is not None + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + tokenize_strategy = self.get_tokenize_strategy(args) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored + + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = self.get_latents_caching_strategy(args) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + + # pbar = ProgressBar(5) + + # Prepare the dataset + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) + if use_user_config: + logger.info(f"Loading dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignoring the following options because config file is found: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + # use arbitrary dataset class + train_dataset_group = train_util.load_arbitrary_dataset(args) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + if args.debug_dataset: + train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + self.assert_extra_args(args, train_dataset_group) # may change some args + + # prepare accelerator + logger.info("preparing accelerator") + accelerator = train_util.prepare_accelerator(args) + + + # Prepare a type that supports mixed precision and cast it as appropriate. + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # Load the model + model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator) + + # text_encoder is List[CLIPTextModel] or CLIPTextModel + text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] + + # pbar.update(1) + + # Load the model for incremental learning + #sys.path.append(os.path.dirname(__file__)) + accelerator.print("import network module:", args.network_module) + package = __name__.split('.')[0] + network_module = importlib.import_module(args.network_module, package=package) + + if args.base_weights is not None: + # base_weights が指定されている場合は、指定された重みを読み込みマージする + for i, weight_path in enumerate(args.base_weights): + if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i: + multiplier = 1.0 + else: + multiplier = args.base_weights_multiplier[i] + + accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}") + + module, weights_sd = network_module.create_network_from_weights( + multiplier, weight_path, vae, text_encoder, unet, for_inference=True + ) + module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") + + accelerator.print(f"all weights merged: {', '.join(args.base_weights)}") + + + # cache latents + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + + train_dataset_group.new_cache_latents(vae, accelerator) + + vae.to("cpu") + clean_memory_on_device(accelerator.device) + + # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu + text_encoding_strategy = self.get_text_encoding_strategy(args) + strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + + text_encoder_outputs_caching_strategy = self.get_text_encoder_outputs_caching_strategy(args) + if text_encoder_outputs_caching_strategy is not None: + strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy) + self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype) + + # pbar.update(1) + + # prepare network + net_kwargs = {} + if args.network_args is not None: + for net_arg in args.network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value + + # if a new network is added in future, add if ~ then blocks for each network (;'∀') + if args.dim_from_weights: + network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs) + else: + if "dropout" not in net_kwargs: + # workaround for LyCORIS (;^ω^) + net_kwargs["dropout"] = args.network_dropout + + network = network_module.create_network( + 1.0, + args.network_dim, + args.network_alpha, + vae, + text_encoder, + unet, + neuron_dropout=args.network_dropout, + **net_kwargs, + ) + if network is None: + return + network_has_multiplier = hasattr(network, "set_multiplier") + + if hasattr(network, "prepare_network"): + network.prepare_network(args) + if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"): + logger.warning( + "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません" + ) + args.scale_weight_norms = False + + self.post_process_network(args, accelerator, network, text_encoders, unet) + + # apply network to unet and text_encoder + train_unet = not args.network_train_text_encoder_only + train_text_encoder = self.is_train_text_encoder(args) + network.apply_to(text_encoder, unet, train_text_encoder, train_unet) + + if args.network_weights is not None: + # FIXME consider alpha of weights: this assumes that the alpha is not changed + info = network.load_weights(args.network_weights) + accelerator.print(f"load network weights from {args.network_weights}: {info}") + + if args.gradient_checkpointing: + if args.cpu_offload_checkpointing: + unet.enable_gradient_checkpointing(cpu_offload=True) + else: + unet.enable_gradient_checkpointing() + + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): + if flag: + if t_enc.supports_gradient_checkpointing: + t_enc.gradient_checkpointing_enable() + del t_enc + network.enable_gradient_checkpointing() # may have no effect + + # Prepare classes necessary for learning + accelerator.print("prepare optimizer, data loader etc.") + + # make backward compatibility for text_encoder_lr + support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs") + if support_multiple_lrs: + text_encoder_lr = args.text_encoder_lr + else: + # toml backward compatibility + if args.text_encoder_lr is None or isinstance(args.text_encoder_lr, float) or isinstance(args.text_encoder_lr, int): + text_encoder_lr = args.text_encoder_lr + else: + text_encoder_lr = None if len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] + try: + if support_multiple_lrs: + results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate) + else: + results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate) + if type(results) is tuple: + trainable_params = results[0] + lr_descriptions = results[1] + else: + trainable_params = results + lr_descriptions = None + except TypeError as e: + trainable_params = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr) + lr_descriptions = None + + # if len(trainable_params) == 0: + # accelerator.print("no trainable parameters found / 学習可能なパラメータが見つかりませんでした") + # for params in trainable_params: + # for k, v in params.items(): + # if type(v) == float: + # pass + # else: + # v = len(v) + # accelerator.print(f"trainable_params: {k} = {v}") + + optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + self.optimizer_train_fn, self.optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args) + + # prepare dataloader + # strategies are set here because they cannot be referenced in another process. Copy them with the dataset + # some strategies can be None + train_dataset_group.set_current_strategies() + + # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # Send learning steps to the dataset side as well + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr scheduler init + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # Experimental function: performs fp16/bf16 learning including gradients, sets the entire model to fp16/bf16 + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + network.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + network.to(weight_dtype) + + unet_weight_dtype = te_weight_dtype = weight_dtype + # Experimental Feature: Put base model into fp8 to save vram + if args.fp8_base or args.fp8_base_unet: + assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0" + assert ( + args.mixed_precision != "no" + ), "fp8_base requires mixed precision='fp16' or 'bf16'" + accelerator.print("enable fp8 training for U-Net.") + unet_weight_dtype = torch.float8_e4m3fn if args.fp8_dtype == "e4m3" else torch.float8_e5m2 + accelerator.print(f"unet_weight_dtype: {unet_weight_dtype}") + + if not args.fp8_base_unet and not args.network_train_unet_only: + accelerator.print("enable fp8 training for Text Encoder.") + te_weight_dtype = torch.float8_e4m3fn if args.fp8_dtype == "e4m3" else torch.float8_e5m2 + + # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM + # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory + + # logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}") + # unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above + logger.info(f"set U-Net weight dtype to {unet_weight_dtype}") + unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator + + unet.requires_grad_(False) + unet.to(dtype=unet_weight_dtype) + for i, t_enc in enumerate(text_encoders): + t_enc.requires_grad_(False) + + # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16 + if t_enc.device.type != "cpu": + t_enc.to(dtype=te_weight_dtype) + + # nn.Embedding not support FP8 + if te_weight_dtype != weight_dtype: + self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good + if args.deepspeed: + flags = self.get_text_encoders_train_flags(args, text_encoders) + ds_model = deepspeed_utils.prepare_deepspeed_model( + args, + unet=unet if train_unet else None, + text_encoder1=text_encoders[0] if flags[0] else None, + text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None, + network=network, + ) + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) + training_model = ds_model + else: + if train_unet: + # default implementation is: unet = accelerator.prepare(unet) + unet = self.prepare_unet_with_accelerator(args, accelerator, unet) # accelerator does some magic here + else: + unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator + if train_text_encoder: + text_encoders = [ + (accelerator.prepare(t_enc) if flag else t_enc) + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)) + ] + if len(text_encoders) > 1: + text_encoder = text_encoders + else: + text_encoder = text_encoders[0] + else: + pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set + + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + network, optimizer, train_dataloader, lr_scheduler + ) + training_model = network + + if args.gradient_checkpointing: + # according to TI example in Diffusers, train is required + unet.train() + for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))): + t_enc.train() + + # set top parameter requires_grad = True for gradient checkpointing works + if frag: + self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc) + + else: + unet.eval() + for t_enc in text_encoders: + t_enc.eval() + + del t_enc + + accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet) + + if not cache_latents: # If you do not cache, VAE will be used, so enable VAE preparation. + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + # Experimental feature: Perform fp16 learning including gradients Apply a patch to PyTorch to enable grad scale in fp16 + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # pbar.update(1) + + # before resuming make hook for saving/loading to save/load the network weights only + def save_model_hook(models, weights, output_dir): + # pop weights of other models than network to save only network weights + # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606 + #if args.deepspeed: + remove_indices = [] + for i, model in enumerate(models): + if not isinstance(model, type(accelerator.unwrap_model(network))): + remove_indices.append(i) + for i in reversed(remove_indices): + if len(weights) > i: + weights.pop(i) + # print(f"save model hook: {len(weights)} weights will be saved") + + # save current ecpoch and step + train_state_file = os.path.join(output_dir, "train_state.json") + # +1 is needed because the state is saved before current_step is set from global_step + logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}") + with open(train_state_file, "w", encoding="utf-8") as f: + json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f) + + steps_from_state = None + + def load_model_hook(models, input_dir): + # remove models except network + remove_indices = [] + for i, model in enumerate(models): + if not isinstance(model, type(accelerator.unwrap_model(network))): + remove_indices.append(i) + for i in reversed(remove_indices): + models.pop(i) + # print(f"load model hook: {len(models)} models will be loaded") + + # load current epoch and step to + nonlocal steps_from_state + train_state_file = os.path.join(input_dir, "train_state.json") + if os.path.exists(train_state_file): + with open(train_state_file, "r", encoding="utf-8") as f: + data = json.load(f) + steps_from_state = data["current_step"] + logger.info(f"load train state from {train_state_file}: {data}") + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # resume from local or huggingface + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # pbar.update(1) + + # Calculate the number of epochs + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + accelerator.print("running training") + accelerator.print(f" num train images * repeats: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch: {len(train_dataloader)}") + accelerator.print(f" num epochs: {num_train_epochs}") + accelerator.print( + f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + accelerator.print(f" gradient accumulation steps: {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps: {args.max_train_steps}") + + # TODO refactor metadata creation and move to util + metadata = { + "ss_session_id": session_id, # random integer indicating which group of epochs the model came from + "ss_training_started_at": training_started_at, # unix timestamp + "ss_output_name": args.output_name, + "ss_learning_rate": args.learning_rate, + "ss_text_encoder_lr": text_encoder_lr, + "ss_unet_lr": args.unet_lr, + "ss_num_train_images": train_dataset_group.num_train_images, + "ss_num_reg_images": train_dataset_group.num_reg_images, + "ss_num_batches_per_epoch": len(train_dataloader), + "ss_num_epochs": num_train_epochs, + "ss_gradient_checkpointing": args.gradient_checkpointing, + "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, + "ss_max_train_steps": args.max_train_steps, + "ss_lr_warmup_steps": args.lr_warmup_steps, + "ss_lr_scheduler": args.lr_scheduler, + "ss_network_module": args.network_module, + "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim + "ss_network_alpha": args.network_alpha, # some networks may not have alpha + "ss_network_dropout": args.network_dropout, # some networks may not have dropout + "ss_mixed_precision": args.mixed_precision, + "ss_full_fp16": bool(args.full_fp16), + "ss_v2": bool(args.v2), + "ss_base_model_version": model_version, + "ss_clip_skip": args.clip_skip, + "ss_max_token_length": args.max_token_length, + "ss_cache_latents": bool(args.cache_latents), + "ss_seed": args.seed, + "ss_lowram": args.lowram, + "ss_noise_offset": args.noise_offset, + "ss_multires_noise_iterations": args.multires_noise_iterations, + "ss_multires_noise_discount": args.multires_noise_discount, + "ss_adaptive_noise_scale": args.adaptive_noise_scale, + "ss_zero_terminal_snr": args.zero_terminal_snr, + "ss_training_comment": args.training_comment, # will not be updated after training + "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), + "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), + "ss_max_grad_norm": args.max_grad_norm, + "ss_caption_dropout_rate": args.caption_dropout_rate, + "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs, + "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate, + "ss_face_crop_aug_range": args.face_crop_aug_range, + "ss_prior_loss_weight": args.prior_loss_weight, + "ss_min_snr_gamma": args.min_snr_gamma, + "ss_scale_weight_norms": args.scale_weight_norms, + "ss_ip_noise_gamma": args.ip_noise_gamma, + "ss_debiased_estimation": bool(args.debiased_estimation_loss), + "ss_noise_offset_random_strength": args.noise_offset_random_strength, + "ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength, + "ss_loss_type": args.loss_type, + "ss_huber_schedule": args.huber_schedule, + "ss_huber_scale": args.huber_scale, + "ss_huber_c": args.huber_c, + "ss_fp8_base": bool(args.fp8_base), + "ss_fp8_base_unet": bool(args.fp8_base_unet), + } + + self.update_metadata(metadata, args) # architecture specific metadata + + if use_user_config: + # save metadata of multiple datasets + # NOTE: pack "ss_datasets" value as json one time + # or should also pack nested collections as json? + datasets_metadata = [] + tag_frequency = {} # merge tag frequency for metadata editor + dataset_dirs_info = {} # merge subset dirs for metadata editor + + for dataset in train_dataset_group.datasets: + is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset) + dataset_metadata = { + "is_dreambooth": is_dreambooth_dataset, + "batch_size_per_device": dataset.batch_size, + "num_train_images": dataset.num_train_images, # includes repeating + "num_reg_images": dataset.num_reg_images, + "resolution": (dataset.width, dataset.height), + "enable_bucket": bool(dataset.enable_bucket), + "min_bucket_reso": dataset.min_bucket_reso, + "max_bucket_reso": dataset.max_bucket_reso, + "tag_frequency": dataset.tag_frequency, + "bucket_info": dataset.bucket_info, + } + + subsets_metadata = [] + for subset in dataset.subsets: + subset_metadata = { + "img_count": subset.img_count, + "num_repeats": subset.num_repeats, + "color_aug": bool(subset.color_aug), + "flip_aug": bool(subset.flip_aug), + "random_crop": bool(subset.random_crop), + "shuffle_caption": bool(subset.shuffle_caption), + "keep_tokens": subset.keep_tokens, + "keep_tokens_separator": subset.keep_tokens_separator, + "secondary_separator": subset.secondary_separator, + "enable_wildcard": bool(subset.enable_wildcard), + "caption_prefix": subset.caption_prefix, + "caption_suffix": subset.caption_suffix, + } + + image_dir_or_metadata_file = None + if subset.image_dir: + image_dir = os.path.basename(subset.image_dir) + subset_metadata["image_dir"] = image_dir + image_dir_or_metadata_file = image_dir + + if is_dreambooth_dataset: + subset_metadata["class_tokens"] = subset.class_tokens + subset_metadata["is_reg"] = subset.is_reg + if subset.is_reg: + image_dir_or_metadata_file = None # not merging reg dataset + else: + metadata_file = os.path.basename(subset.metadata_file) + subset_metadata["metadata_file"] = metadata_file + image_dir_or_metadata_file = metadata_file # may overwrite + + subsets_metadata.append(subset_metadata) + + # merge dataset dir: not reg subset only + # TODO update additional-network extension to show detailed dataset config from metadata + if image_dir_or_metadata_file is not None: + # datasets may have a certain dir multiple times + v = image_dir_or_metadata_file + i = 2 + while v in dataset_dirs_info: + v = image_dir_or_metadata_file + f" ({i})" + i += 1 + image_dir_or_metadata_file = v + + dataset_dirs_info[image_dir_or_metadata_file] = { + "n_repeats": subset.num_repeats, + "img_count": subset.img_count, + } + + dataset_metadata["subsets"] = subsets_metadata + datasets_metadata.append(dataset_metadata) + + # merge tag frequency: + for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items(): + # If a directory is used by multiple datasets, count only once + # Since the number of repetitions is originally specified, the number of times a tag appears in the caption does not match the number of times it is used in training. + # Therefore, it is not very meaningful to add up the number of times for multiple datasets here. + if ds_dir_name in tag_frequency: + continue + tag_frequency[ds_dir_name] = ds_freq_for_dir + + metadata["ss_datasets"] = json.dumps(datasets_metadata) + metadata["ss_tag_frequency"] = json.dumps(tag_frequency) + metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info) + else: + # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir + assert ( + len(train_dataset_group.datasets) == 1 + ), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug." + + dataset = train_dataset_group.datasets[0] + + dataset_dirs_info = {} + reg_dataset_dirs_info = {} + if use_dreambooth_method: + for subset in dataset.subsets: + info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info + info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count} + else: + for subset in dataset.subsets: + dataset_dirs_info[os.path.basename(subset.metadata_file)] = { + "n_repeats": subset.num_repeats, + "img_count": subset.img_count, + } + + metadata.update( + { + "ss_batch_size_per_device": args.train_batch_size, + "ss_total_batch_size": total_batch_size, + "ss_resolution": args.resolution, + "ss_color_aug": bool(args.color_aug), + "ss_flip_aug": bool(args.flip_aug), + "ss_random_crop": bool(args.random_crop), + "ss_shuffle_caption": bool(args.shuffle_caption), + "ss_enable_bucket": bool(dataset.enable_bucket), + "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), + "ss_min_bucket_reso": dataset.min_bucket_reso, + "ss_max_bucket_reso": dataset.max_bucket_reso, + "ss_keep_tokens": args.keep_tokens, + "ss_dataset_dirs": json.dumps(dataset_dirs_info), + "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), + "ss_tag_frequency": json.dumps(dataset.tag_frequency), + "ss_bucket_info": json.dumps(dataset.bucket_info), + } + ) + + # add extra args + if args.network_args: + metadata["ss_network_args"] = json.dumps(net_kwargs) + + # model name and hash + if args.pretrained_model_name_or_path is not None: + sd_model_name = args.pretrained_model_name_or_path + if os.path.exists(sd_model_name): + metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) + metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) + sd_model_name = os.path.basename(sd_model_name) + metadata["ss_sd_model_name"] = sd_model_name + + if args.vae is not None: + vae_name = args.vae + if os.path.exists(vae_name): + metadata["ss_vae_hash"] = train_util.model_hash(vae_name) + metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name) + vae_name = os.path.basename(vae_name) + metadata["ss_vae_name"] = vae_name + + metadata = {k: str(v) for k, v in metadata.items()} + + # make minimum metadata for filtering + minimum_metadata = {} + for key in train_util.SS_METADATA_MINIMUM_KEYS: + if key in metadata: + minimum_metadata[key] = metadata[key] + + # calculate steps to skip when resuming or starting from a specific step + initial_step = 0 + if args.initial_epoch is not None or args.initial_step is not None: + # if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming + if steps_from_state is not None: + logger.warning( + "steps from the state is ignored because initial_step is specified" + ) + if args.initial_step is not None: + initial_step = args.initial_step + else: + # num steps per epoch is calculated by num_processes and gradient_accumulation_steps + initial_step = (args.initial_epoch - 1) * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + else: + # if initial_epoch and initial_step are not specified, steps_from_state is used when resuming + if steps_from_state is not None: + initial_step = steps_from_state + steps_from_state = None + + if initial_step > 0: + assert ( + args.max_train_steps > initial_step + ), f"max_train_steps should be greater than initial step: {args.max_train_steps} vs {initial_step}" + + epoch_to_start = 0 + if initial_step > 0: + if args.skip_until_initial_step: + # if skip_until_initial_step is specified, load data and discard it to ensure the same data is used + if not args.resume: + logger.info( + f"initial_step is specified but not resuming. lr scheduler will be started from the beginning" + ) + logger.info(f"skipping {initial_step} steps") + initial_step *= args.gradient_accumulation_steps + + # set epoch to start to make initial_step less than len(train_dataloader) + epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + else: + # if not, only epoch no is skipped for informative purpose + epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + initial_step = 0 # do not skip + + + + noise_scheduler = self.get_noise_scheduler(args, accelerator.device) + + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "network_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + self.loss_recorder = train_util.LossRecorder() + del train_dataset_group + + # pbar.update(1) + + # callback for step start + if hasattr(accelerator.unwrap_model(network), "on_step_start"): + on_step_start_for_network = accelerator.unwrap_model(network).on_step_start + else: + on_step_start_for_network = lambda *args, **kwargs: None + + # function for saving/removing + def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + metadata["ss_training_finished_at"] = str(time.time()) + metadata["ss_steps"] = str(steps) + metadata["ss_epoch"] = str(epoch_no) + + metadata_to_save = minimum_metadata if args.no_metadata else metadata + sai_metadata = self.get_sai_model_spec(args) + metadata_to_save.update(sai_metadata) + + unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + if self.is_text_encoder_not_needed_for_training(args): + logger.info("text_encoder is not needed for training. deleting to save memory.") + for t_enc in text_encoders: + del t_enc + text_encoders = [] + text_encoder = None + + # For --sample_at_first + #self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) + + self.global_step = 0 + # training loop + if initial_step > 0: # only if skip_until_initial_step is specified + for skip_epoch in range(epoch_to_start): # skip epochs + logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") + initial_step -= len(train_dataloader) + + # log device and dtype for each model + logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") + for i, t_enc in enumerate(text_encoders): + params_itr = t_enc.parameters() + params_itr.__next__() # skip the first parameter + params_itr.__next__() # skip the second parameter. because CLIP first two parameters are embeddings + param_3rd = params_itr.__next__() + logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}") + + clean_memory_on_device(accelerator.device) + + self.epoch_to_start = epoch_to_start + self.num_train_epochs = num_train_epochs + self.accelerator = accelerator + self.network = network + self.text_encoder = text_encoder + self.unet = unet + self.vae = vae + self.tokenizers = tokenizers + self.args = args + self.train_dataloader = train_dataloader + self.initial_step = initial_step + self.current_epoch = current_epoch + self.metadata = metadata + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.save_model = save_model + self.remove_model = remove_model + # self.comfy_pbar = None + + progress_bar = tqdm(range(args.max_train_steps - initial_step), smoothing=0, disable=False, desc="steps") + + def training_loop(break_at_steps, epoch): + steps_done = 0 + + #accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + progress_bar.set_description(f"Epoch {epoch + 1}/{num_train_epochs} - steps") + + current_epoch.value = epoch + 1 + + metadata["ss_epoch"] = str(epoch + 1) + + accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) + + skipped_dataloader = None + if self.initial_step > 0: + skipped_dataloader = accelerator.skip_first_batches(train_dataloader, self.initial_step - 1) + self.initial_step = 1 + + for step, batch in enumerate(skipped_dataloader or train_dataloader): + current_step.value = self.global_step + if self.initial_step > 0: + self.initial_step -= 1 + continue + + with accelerator.accumulate(training_model): + on_step_start_for_network(text_encoder, unet) + + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) + else: + with torch.no_grad(): + # encode latents + latents = self.encode_images_to_latents(args, accelerator, vae, batch["images"].to(vae_dtype)) + latents = latents.to(dtype=weight_dtype) + + # NaN check + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + latents = self.shift_scale_latents(args, latents) + + # get multiplier for each sample + if network_has_multiplier: + multipliers = batch["network_multipliers"] + # if all multipliers are same, use single multiplier + if torch.all(multipliers == multipliers[0]): + multipliers = multipliers[0].item() + else: + raise NotImplementedError("multipliers for each sample is not supported yet") + # print(f"set multiplier: {multipliers}") + accelerator.unwrap_model(network).set_multiplier(multipliers) + + text_encoder_conds = [] + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs + + if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: + # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached' + with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): + # Get the text embedding for conditioning + if args.weighted_captions: + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids_list, + weights_list, + ) + else: + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens( + tokenize_strategy, + self.get_models_for_text_encoding(args, accelerator, text_encoders), + input_ids, + ) + if args.full_fp16: + encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] + + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + + # sample noise, call unet, get target + noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + ) + + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # weight for each sample + loss = loss * loss_weights + + # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc. + loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + + loss = loss.mean() # No need to divide by batch_size since it's an average + + accelerator.backward(loss) + if accelerator.sync_gradients: + self.all_reduce_network(accelerator, network) # sync DDP grad manually + if args.max_grad_norm != 0.0: + params_to_clip = accelerator.unwrap_model(network).get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + if args.scale_weight_norms: + keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization( + args.scale_weight_norms, accelerator.device + ) + max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} + else: + keys_scaled, mean_norm, maximum_norm = None, None, None + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + self.global_step += 1 + + current_loss = loss.detach().item() + self.loss_recorder.add(epoch=epoch, step=step, global_step=self.global_step, loss=current_loss) + avr_loss: float = self.loss_recorder.moving_average + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if args.scale_weight_norms: + progress_bar.set_postfix(**{**max_mean_logs, **logs}) + + if len(accelerator.trackers) > 0: + logs = self.generate_step_logs( + args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm + ) + accelerator.log(logs, step=self.global_step) + + if self.global_step >= break_at_steps: + break + steps_done += 1 + # self.comfy_pbar.update(1) + + if len(accelerator.trackers) > 0: + logs = {"loss/epoch": self.loss_recorder.moving_average} + accelerator.log(logs, step=epoch + 1) + + return steps_done + + + + return training_loop + + # metadata["ss_epoch"] = str(num_train_epochs) + # metadata["ss_training_finished_at"] = str(time.time()) + + # network = accelerator.unwrap_model(network) + + # accelerator.end_training() + + # if (args.save_state or args.save_state_on_train_end): + # train_util.save_state_on_train_end(args, accelerator) + + # ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) + # save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) + + # logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, True) + train_util.add_masked_loss_arguments(parser) + deepspeed_utils.add_deepspeed_arguments(parser) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + train_util.add_dit_training_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + + parser.add_argument( + "--cpu_offload_checkpointing", + action="store_true", + help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported" + " / 勾配チェックポイント時にテンソルをCPUにオフロードする(U-NetまたはDiTのみ、サポートされている場合)", + ) + parser.add_argument( + "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない" + ) + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) + + parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") + parser.add_argument( + "--text_encoder_lr", + type=float, + default=None, + nargs="*", + help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能", + ) + parser.add_argument( + "--fp8_base_unet", + action="store_true", + help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16" + " / U-Net(またはDiT)にfp8を使用する。Text Encoderはfp16またはbf16", + ) + + parser.add_argument( + "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み" + ) + parser.add_argument( + "--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール" + ) + parser.add_argument( + "--network_dim", + type=int, + default=None, + help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)", + ) + parser.add_argument( + "--network_alpha", + type=float, + default=1, + help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)", + ) + parser.add_argument( + "--network_dropout", + type=float, + default=None, + help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)", + ) + parser.add_argument( + "--network_args", + type=str, + default=None, + nargs="*", + help="additional arguments for network (key=value) / ネットワークへの追加の引数", + ) + parser.add_argument( + "--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する" + ) + parser.add_argument( + "--network_train_text_encoder_only", + action="store_true", + help="only training Text Encoder part / Text Encoder関連部分のみ学習する", + ) + parser.add_argument( + "--training_comment", + type=str, + default=None, + help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列", + ) + parser.add_argument( + "--dim_from_weights", + action="store_true", + help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する", + ) + parser.add_argument( + "--scale_weight_norms", + type=float, + default=None, + help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)", + ) + parser.add_argument( + "--base_weights", + type=str, + default=None, + nargs="*", + help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル", + ) + parser.add_argument( + "--base_weights_multiplier", + type=float, + default=None, + nargs="*", + help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率", + ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + parser.add_argument( + "--skip_until_initial_step", + action="store_true", + help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする", + ) + parser.add_argument( + "--initial_epoch", + type=int, + default=None, + help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`." + + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる", + ) + parser.add_argument( + "--initial_step", + type=int, + default=None, + help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch." + + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする", + ) + # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") + # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") + # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + trainer = NetworkTrainer() + trainer.train(args) diff --git a/update.js b/update.js new file mode 100644 index 0000000000000000000000000000000000000000..f98498574cadd258d4e00726d831a8469d6cd8d8 --- /dev/null +++ b/update.js @@ -0,0 +1,46 @@ +module.exports = { + run: [{ + method: "shell.run", + params: { + message: "git pull" + } + }, { + method: "shell.run", + params: { + path: "sd-scripts", + message: "git pull" + } + }, { + method: "shell.run", + params: { + path: "sd-scripts", + venv: "../env", + message: [ + "uv pip install -r requirements.txt", + ] + } + }, { + method: "shell.run", + params: { + venv: "env", + message: [ + "pip uninstall -y diffusers[torch] torch torchaudio torchvision", + "uv pip install -r requirements.txt", + ] + } + }, { + method: "script.start", + params: { + uri: "torch.js", + params: { + venv: "env", + // xformers: true // uncomment this line if your project requires xformers + } + } + }, { + method: "fs.link", + params: { + venv: "env" + } + }] +}