tonyshark committed
Commit cc69848 · verified · Parent(s): 1b6dc48

Upload 132 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +7 -0
  2. Dockerfile +53 -0
  3. Dockerfile.cuda12.4 +53 -0
  4. LICENSE +7 -0
  5. README.md +54 -0
  6. __init__.py +12 -0
  7. __pycache__/train_network.cpython-310.pyc +0 -0
  8. advanced.png +3 -0
  9. app-launch.sh +5 -0
  10. app.py +1119 -0
  11. datasets/1 +0 -0
  12. docker-compose.yml +28 -0
  13. fine_tune.py +560 -0
  14. flags.png +0 -0
  15. flow.gif +3 -0
  16. flux_extract_lora.py +221 -0
  17. flux_train_comfy.py +806 -0
  18. flux_train_network_comfy.py +500 -0
  19. hf_token.json +3 -0
  20. icon.png +0 -0
  21. install.js +96 -0
  22. library/__init__.py +0 -0
  23. library/__pycache__/__init__.cpython-310.pyc +0 -0
  24. library/__pycache__/config_util.cpython-310.pyc +0 -0
  25. library/__pycache__/custom_offloading_utils.cpython-310.pyc +0 -0
  26. library/__pycache__/custom_train_functions.cpython-310.pyc +0 -0
  27. library/__pycache__/deepspeed_utils.cpython-310.pyc +0 -0
  28. library/__pycache__/device_utils.cpython-310.pyc +0 -0
  29. library/__pycache__/flux_models.cpython-310.pyc +0 -0
  30. library/__pycache__/flux_train_utils.cpython-310.pyc +0 -0
  31. library/__pycache__/flux_utils.cpython-310.pyc +0 -0
  32. library/__pycache__/huggingface_util.cpython-310.pyc +0 -0
  33. library/__pycache__/model_util.cpython-310.pyc +0 -0
  34. library/__pycache__/original_unet.cpython-310.pyc +0 -0
  35. library/__pycache__/sai_model_spec.cpython-310.pyc +0 -0
  36. library/__pycache__/sd3_models.cpython-310.pyc +0 -0
  37. library/__pycache__/sd3_utils.cpython-310.pyc +0 -0
  38. library/__pycache__/strategy_base.cpython-310.pyc +0 -0
  39. library/__pycache__/strategy_sd.cpython-310.pyc +0 -0
  40. library/__pycache__/train_util.cpython-310.pyc +3 -0
  41. library/__pycache__/utils.cpython-310.pyc +0 -0
  42. library/adafactor_fused.py +138 -0
  43. library/attention_processors.py +227 -0
  44. library/config_util.py +717 -0
  45. library/custom_offloading_utils.py +227 -0
  46. library/custom_train_functions.py +556 -0
  47. library/deepspeed_utils.py +139 -0
  48. library/device_utils.py +84 -0
  49. library/flux_models.py +1060 -0
  50. library/flux_train_utils.py +585 -0
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ advanced.png filter=lfs diff=lfs merge=lfs -text
+ flow.gif filter=lfs diff=lfs merge=lfs -text
+ library/__pycache__/train_util.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
+ publish_to_hf.png filter=lfs diff=lfs merge=lfs -text
+ sample.png filter=lfs diff=lfs merge=lfs -text
+ screenshot.png filter=lfs diff=lfs merge=lfs -text
+ seed.gif filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,53 @@
+ # Base image with CUDA 12.2
+ FROM nvidia/cuda:12.2.2-base-ubuntu22.04
+
+ # Install pip if not already installed
+ RUN apt-get update -y && apt-get install -y \
+     python3-pip \
+     python3-dev \
+     git \
+     build-essential # Install dependencies for building extensions
+
+ # Define environment variables for UID and GID and local timezone
+ # ENV PUID=${PUID:-1000}
+ # ENV PGID=${PGID:-1000}
+
+ # Create a group with the specified GID
+ # RUN groupadd -g "${PGID}" appuser
+ # Create a user with the specified UID and GID
+ # RUN useradd -m -s /bin/sh -u "${PUID}" -g "${PGID}" appuser
+
+ WORKDIR /app
+
+ # Get sd-scripts from kohya-ss and install them
+ RUN git clone -b sd3 https://github.com/kohya-ss/sd-scripts && \
+     cd sd-scripts && \
+     pip install --no-cache-dir -r ./requirements.txt
+
+ # Install main application dependencies
+ COPY ./requirements.txt ./requirements.txt
+ RUN pip install --no-cache-dir -r ./requirements.txt
+
+ # Install Torch, Torchvision, and Torchaudio for CUDA 12.2
+ RUN pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu122/torch_stable.html
+
+ # The appuser account creation is commented out above, so the ownership change is skipped as well
+ # (an active "RUN chown -R appuser:appuser /app" would fail the build without that user)
+ # RUN chown -R appuser:appuser /app
+
+ # Delete redundant requirements.txt and sd-scripts directory within the container
+ RUN rm -r ./sd-scripts
+ RUN rm ./requirements.txt
+ RUN pip install --force-reinstall -v "triton==3.1.0"
+ # Run application as non-root
+ # USER appuser
+
+ # Copy fluxgym application code
+ COPY . ./fluxgym
+
+ EXPOSE 7860
+
+ ENV GRADIO_SERVER_NAME="0.0.0.0"
+
+ WORKDIR /app/fluxgym
+
+ # Run fluxgym Python application
+ CMD ["python3", "./app.py"]
Dockerfile.cuda12.4 ADDED
@@ -0,0 +1,53 @@
+ # Base image with CUDA 12.4
+ FROM nvidia/cuda:12.4.1-base-ubuntu22.04
+
+ # Install pip if not already installed
+ RUN apt-get update -y && apt-get install -y \
+     python3-pip \
+     python3-dev \
+     git \
+     build-essential # Install dependencies for building extensions
+
+ # Define environment variables for UID and GID and local timezone
+ ENV PUID=${PUID:-1000}
+ ENV PGID=${PGID:-1000}
+
+ # Create a group with the specified GID
+ RUN groupadd -g "${PGID}" appuser
+ # Create a user with the specified UID and GID
+ RUN useradd -m -s /bin/sh -u "${PUID}" -g "${PGID}" appuser
+
+ WORKDIR /app
+
+ # Get sd-scripts from kohya-ss and install them
+ RUN git clone -b sd3 https://github.com/kohya-ss/sd-scripts && \
+     cd sd-scripts && \
+     pip install --no-cache-dir -r ./requirements.txt
+
+ # Install main application dependencies
+ COPY ./requirements.txt ./requirements.txt
+ RUN pip install --no-cache-dir -r ./requirements.txt
+
+ # Install Torch, Torchvision, and Torchaudio for CUDA 12.4
+ RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
+
+ RUN chown -R appuser:appuser /app
+
+ # Delete redundant requirements.txt and sd-scripts directory within the container
+ RUN rm -r ./sd-scripts
+ RUN rm ./requirements.txt
+
+ # Run application as non-root
+ USER appuser
+
+ # Copy fluxgym application code
+ COPY . ./fluxgym
+
+ EXPOSE 7860
+
+ ENV GRADIO_SERVER_NAME="0.0.0.0"
+
+ WORKDIR /app/fluxgym
+
+ # Run fluxgym Python application
+ CMD ["python3", "./app.py"]
LICENSE ADDED
@@ -0,0 +1,7 @@
+ Copyright 2024 cocktailpeanut
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
README.md ADDED
@@ -0,0 +1,54 @@
+ # ComfyUI Flux Trainer
+
+ A wrapper for slightly modified versions of kohya's training scripts: https://github.com/kohya-ss/sd-scripts
+
+ Including code from: https://github.com/KohakuBlueleaf/Lycoris
+
+ And from: https://github.com/LoganBooker/prodigy-plus-schedule-free
+
+ ## DISCLAIMER
+ I have **very** little previous experience in training anything; Flux is basically the first model I've been inspired to learn. Previously I've only trained AnimateDiff Motion LoRAs, and built similar training nodes for it.
+
+ ## DO NOT ASK ME FOR TRAINING ADVICE
+ I cannot emphasize this enough: this repository is not the place for questions about training itself; those are better directed to kohya's repo. Even so, keep in mind that my implementation may have mistakes.
+
+ The default settings aren't necessarily any good; they are just the last (out of many) I tried that worked for my dataset.
+
+ # THIS IS EXPERIMENTAL
+ Both these nodes and the underlying implementation by kohya are works in progress and expected to change.
+
+ # Installation
+ 1. Clone this repo into the `custom_nodes` folder.
+ 2. Install dependencies: `pip install -r requirements.txt`
+ or, if you use the portable install, run this in the ComfyUI_windows_portable folder:
+
+ `python_embeded\python.exe -m pip install -r ComfyUI\custom_nodes\ComfyUI-FluxTrainer\requirements.txt`
+
+ In addition, torch version 2.4.0 or higher is highly recommended.
+
+ An example workflow for LoRA training can be found in the examples folder; it utilizes additional nodes from:
+
+ https://github.com/kijai/ComfyUI-KJNodes
+
+ And some (optional) debugging nodes from:
+
+ https://github.com/rgthree/rgthree-comfy
+
+ For LoRA training, the models need to be the normal fp8 or fp16 versions; also make sure the VAE is the non-diffusers version:
+
+ https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/ae.safetensors
+
+ For full model training, the fp16 version of the main model needs to be used.
+
+ ## Why train in ComfyUI?
+ - Familiar UI (obviously only if you are a Comfy user already)
+ - You can use the same models you use for inference
+ - You can use the same Python environment; I faced no incompatibilities
+ - You can build workflows to compare settings etc.
+
+ Currently supports LoRA training, as well as (untested) full fine-tuning, using code from kohya's scripts: https://github.com/kohya-ss/sd-scripts
+
+ Experimental support for LyCORIS training has been added as well, using code from: https://github.com/KohakuBlueleaf/Lycoris
+
+ ![Screenshot 2024-08-21 020207](https://github.com/user-attachments/assets/1686b180-90c8-41d0-8c96-63e76ebc2475)
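The README above recommends torch 2.4.0 or higher, so a quick sanity check of the environment ComfyUI actually runs can save a failed install. This is a minimal sketch, assuming it is executed with the same Python interpreter ComfyUI uses (for the portable install, `python_embeded\python.exe`); the file name `check_torch.py` is invented and not part of this commit:

    # check_torch.py -- hypothetical helper, not part of the upload
    import torch

    print("torch version:", torch.__version__)
    # "2.4.0+cu121" -> (2, 4); compare against the recommended minimum
    major, minor = (int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
    if (major, minor) < (2, 4):
        print("torch is older than the recommended 2.4.0; consider upgrading")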
__init__.py ADDED
@@ -0,0 +1,12 @@
+ from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
+ from .nodes_sd3 import NODE_CLASS_MAPPINGS as NODE_CLASS_MAPPINGS_SD3
+ from .nodes_sd3 import NODE_DISPLAY_NAME_MAPPINGS as NODE_DISPLAY_NAME_MAPPINGS_SD3
+ from .nodes_sdxl import NODE_CLASS_MAPPINGS as NODE_CLASS_MAPPINGS_SDXL
+ from .nodes_sdxl import NODE_DISPLAY_NAME_MAPPINGS as NODE_DISPLAY_NAME_MAPPINGS_SDXL
+
+ NODE_CLASS_MAPPINGS.update(NODE_CLASS_MAPPINGS_SD3)
+ NODE_CLASS_MAPPINGS.update(NODE_CLASS_MAPPINGS_SDXL)
+ NODE_DISPLAY_NAME_MAPPINGS.update(NODE_DISPLAY_NAME_MAPPINGS_SD3)
+ NODE_DISPLAY_NAME_MAPPINGS.update(NODE_DISPLAY_NAME_MAPPINGS_SDXL)
+
+ __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
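This `__init__.py` relies on ComfyUI's loader convention: each nodes module exports a `NODE_CLASS_MAPPINGS` dict (node name to class) and a `NODE_DISPLAY_NAME_MAPPINGS` dict (node name to UI label), and they are merged here with plain `dict.update` calls, so later updates win on any name collision. The repository's actual `nodes.py` is not shown in this 50-file view; the sketch below only illustrates the general shape such a module takes, and `ExampleNode` is invented for illustration:

    # hypothetical nodes module, for illustration only (not the repo's nodes.py)
    class ExampleNode:
        """A minimal ComfyUI-style node that upper-cases a string."""

        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"text": ("STRING", {"default": ""})}}

        RETURN_TYPES = ("STRING",)
        FUNCTION = "run"
        CATEGORY = "FluxTrainer/examples"

        def run(self, text):
            return (text.upper(),)

    NODE_CLASS_MAPPINGS = {"ExampleNode": ExampleNode}
    NODE_DISPLAY_NAME_MAPPINGS = {"ExampleNode": "Example Node"}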
__pycache__/train_network.cpython-310.pyc ADDED
Binary file (38.9 kB).
 
advanced.png ADDED

Git LFS Details

  • SHA256: 15077625eb185463cc0dd383157879fe3b73ebb7305a40f5ed2af14a49bca41d
  • Pointer size: 131 Bytes
  • Size of remote file: 182 kB
app-launch.sh ADDED
@@ -0,0 +1,5 @@
+ #!/usr/bin/env bash
+
+ cd "`dirname "$0"`" || exit 1
+ . env/bin/activate
+ python app.py
app.py ADDED
@@ -0,0 +1,1119 @@
1
+ import os
2
+ import sys
3
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
4
+ os.environ['GRADIO_ANALYTICS_ENABLED'] = '0'
5
+ sys.path.insert(0, os.getcwd())
6
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'sd-scripts'))
7
+ import subprocess
8
+ import gradio as gr
9
+ from PIL import Image
10
+ import torch
11
+ import uuid
12
+ import shutil
13
+ import json
14
+ import yaml
15
+ from slugify import slugify
16
+ from transformers import AutoProcessor, AutoModelForCausalLM
17
+ from gradio_logsview import LogsView, LogsViewRunner
18
+ from huggingface_hub import hf_hub_download, HfApi
19
+ from library import flux_train_utils, huggingface_util
20
+ from argparse import Namespace
21
+ import train_network
22
+ import toml
23
+ import re
24
+ MAX_IMAGES = 150
25
+
26
+ with open('models.yaml', 'r') as file:
27
+ models = yaml.safe_load(file)
28
+
29
+ def readme(base_model, lora_name, instance_prompt, sample_prompts):
30
+
31
+ # model license
32
+ model_config = models[base_model]
33
+ model_file = model_config["file"]
34
+ base_model_name = model_config["base"]
35
+ license = None
36
+ license_name = None
37
+ license_link = None
38
+ license_items = []
39
+ if "license" in model_config:
40
+ license = model_config["license"]
41
+ license_items.append(f"license: {license}")
42
+ if "license_name" in model_config:
43
+ license_name = model_config["license_name"]
44
+ license_items.append(f"license_name: {license_name}")
45
+ if "license_link" in model_config:
46
+ license_link = model_config["license_link"]
47
+ license_items.append(f"license_link: {license_link}")
48
+ license_str = "\n".join(license_items)
49
+ print(f"license_items={license_items}")
50
+ print(f"license_str = {license_str}")
51
+
52
+ # tags
53
+ tags = [ "text-to-image", "flux", "lora", "diffusers", "template:sd-lora", "fluxgym" ]
54
+
55
+ # widgets
56
+ widgets = []
57
+ sample_image_paths = []
58
+ output_name = slugify(lora_name)
59
+ samples_dir = resolve_path_without_quotes(f"outputs/{output_name}/sample")
60
+ try:
61
+ for filename in os.listdir(samples_dir):
62
+ # Filename Schema: [name]_[steps]_[index]_[timestamp].png
63
+ match = re.search(r"_(\d+)_(\d+)_(\d+)\.png$", filename)
64
+ if match:
65
+ steps, index, timestamp = int(match.group(1)), int(match.group(2)), int(match.group(3))
66
+ sample_image_paths.append((steps, index, f"sample/{filename}"))
67
+
68
+ # Sort by numeric index
69
+ sample_image_paths.sort(key=lambda x: x[0], reverse=True)
70
+
71
+ final_sample_image_paths = sample_image_paths[:len(sample_prompts)]
72
+ final_sample_image_paths.sort(key=lambda x: x[1])
73
+ for i, prompt in enumerate(sample_prompts):
74
+ _, _, image_path = final_sample_image_paths[i]
75
+ widgets.append(
76
+ {
77
+ "text": prompt,
78
+ "output": {
79
+ "url": image_path
80
+ },
81
+ }
82
+ )
83
+ except:
84
+ print(f"no samples")
85
+ dtype = "torch.bfloat16"
86
+ # Construct the README content
87
+ readme_content = f"""---
88
+ tags:
89
+ {yaml.dump(tags, indent=4).strip()}
90
+ {"widget:" if os.path.isdir(samples_dir) else ""}
91
+ {yaml.dump(widgets, indent=4).strip() if widgets else ""}
92
+ base_model: {base_model_name}
93
+ {"instance_prompt: " + instance_prompt if instance_prompt else ""}
94
+ {license_str}
95
+ ---
96
+
97
+ # {lora_name}
98
+
99
+ A Flux LoRA trained on a local computer with [Fluxgym](https://github.com/cocktailpeanut/fluxgym)
100
+
101
+ <Gallery />
102
+
103
+ ## Trigger words
104
+
105
+ {"You should use `" + instance_prompt + "` to trigger the image generation." if instance_prompt else "No trigger words defined."}
106
+
107
+ ## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, Forge, etc.
108
+
109
+ Weights for this model are available in Safetensors format.
110
+
111
+ """
112
+ return readme_content
113
+
114
+ def account_hf():
115
+ try:
116
+ with open("HF_TOKEN", "r") as file:
117
+ token = file.read()
118
+ api = HfApi(token=token)
119
+ try:
120
+ account = api.whoami()
121
+ return { "token": token, "account": account['name'] }
122
+ except:
123
+ return None
124
+ except:
125
+ return None
126
+
127
+ """
128
+ hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner])
129
+ """
130
+ def logout_hf():
131
+ os.remove("HF_TOKEN")
132
+ global current_account
133
+ current_account = account_hf()
134
+ print(f"current_account={current_account}")
135
+ return gr.update(value=""), gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False)
136
+
137
+
138
+ """
139
+ hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner])
140
+ """
141
+ def login_hf(hf_token):
142
+ api = HfApi(token=hf_token)
143
+ try:
144
+ account = api.whoami()
145
+ if account != None:
146
+ if "name" in account:
147
+ with open("HF_TOKEN", "w") as file:
148
+ file.write(hf_token)
149
+ global current_account
150
+ current_account = account_hf()
151
+ return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(value=current_account["account"], visible=True)
152
+ return gr.update(), gr.update(), gr.update(), gr.update()
153
+ except:
154
+ print(f"incorrect hf_token")
155
+ return gr.update(), gr.update(), gr.update(), gr.update()
156
+
157
+ def upload_hf(base_model, lora_rows, repo_owner, repo_name, repo_visibility, hf_token):
158
+ src = lora_rows
159
+ repo_id = f"{repo_owner}/{repo_name}"
160
+ gr.Info(f"Uploading to Huggingface. Please Stand by...", duration=None)
161
+ args = Namespace(
162
+ huggingface_repo_id=repo_id,
163
+ huggingface_repo_type="model",
164
+ huggingface_repo_visibility=repo_visibility,
165
+ huggingface_path_in_repo="",
166
+ huggingface_token=hf_token,
167
+ async_upload=False
168
+ )
169
+ print(f"upload_hf args={args}")
170
+ huggingface_util.upload(args=args, src=src)
171
+ gr.Info(f"[Upload Complete] https://huggingface.co/{repo_id}", duration=None)
172
+
173
+ def load_captioning(uploaded_files, concept_sentence):
174
+ uploaded_images = [file for file in uploaded_files if not file.endswith('.txt')]
175
+ txt_files = [file for file in uploaded_files if file.endswith('.txt')]
176
+ txt_files_dict = {os.path.splitext(os.path.basename(txt_file))[0]: txt_file for txt_file in txt_files}
177
+ updates = []
178
+ if len(uploaded_images) <= 1:
179
+ raise gr.Error(
180
+ "Please upload at least 2 images to train your model (the ideal number with default settings is between 4-30)"
181
+ )
182
+ elif len(uploaded_images) > MAX_IMAGES:
183
+ raise gr.Error(f"For now, only {MAX_IMAGES} or less images are allowed for training")
184
+ # Update for the captioning_area
185
+ # for _ in range(3):
186
+ updates.append(gr.update(visible=True))
187
+ # Update visibility and image for each captioning row and image
188
+ for i in range(1, MAX_IMAGES + 1):
189
+ # Determine if the current row and image should be visible
190
+ visible = i <= len(uploaded_images)
191
+
192
+ # Update visibility of the captioning row
193
+ updates.append(gr.update(visible=visible))
194
+
195
+ # Update for image component - display image if available, otherwise hide
196
+ image_value = uploaded_images[i - 1] if visible else None
197
+ updates.append(gr.update(value=image_value, visible=visible))
198
+
199
+ corresponding_caption = False
200
+ if(image_value):
201
+ base_name = os.path.splitext(os.path.basename(image_value))[0]
202
+ if base_name in txt_files_dict:
203
+ with open(txt_files_dict[base_name], 'r') as file:
204
+ corresponding_caption = file.read()
205
+
206
+ # Update value of captioning area
207
+ text_value = corresponding_caption if visible and corresponding_caption else concept_sentence if visible and concept_sentence else None
208
+ updates.append(gr.update(value=text_value, visible=visible))
209
+
210
+ # Update for the sample caption area
211
+ updates.append(gr.update(visible=True))
212
+ updates.append(gr.update(visible=True))
213
+
214
+ return updates
215
+
216
+ def hide_captioning():
217
+ return gr.update(visible=False), gr.update(visible=False)
218
+
219
+ def resize_image(image_path, output_path, size):
220
+ with Image.open(image_path) as img:
221
+ width, height = img.size
222
+ if width < height:
223
+ new_width = size
224
+ new_height = int((size/width) * height)
225
+ else:
226
+ new_height = size
227
+ new_width = int((size/height) * width)
228
+ print(f"resize {image_path} : {new_width}x{new_height}")
229
+ img_resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
230
+ img_resized.save(output_path)
231
+
232
+ def create_dataset(destination_folder, size, *inputs):
233
+ print("Creating dataset")
234
+ images = inputs[0]
235
+ if not os.path.exists(destination_folder):
236
+ os.makedirs(destination_folder)
237
+
238
+ for index, image in enumerate(images):
239
+ # copy the images to the datasets folder
240
+ new_image_path = shutil.copy(image, destination_folder)
241
+
242
+ # if it's a caption text file skip the next bit
243
+ ext = os.path.splitext(new_image_path)[-1].lower()
244
+ if ext == '.txt':
245
+ continue
246
+
247
+ # resize the images
248
+ resize_image(new_image_path, new_image_path, size)
249
+
250
+ # copy the captions
251
+
252
+ original_caption = inputs[index + 1]
253
+
254
+ image_file_name = os.path.basename(new_image_path)
255
+ caption_file_name = os.path.splitext(image_file_name)[0] + ".txt"
256
+ caption_path = resolve_path_without_quotes(os.path.join(destination_folder, caption_file_name))
257
+ print(f"image_path={new_image_path}, caption_path = {caption_path}, original_caption={original_caption}")
258
+ # if caption_path exists, do not write
259
+ if os.path.exists(caption_path):
260
+ print(f"{caption_path} already exists. use the existing .txt file")
261
+ else:
262
+ print(f"{caption_path} create a .txt caption file")
263
+ with open(caption_path, 'w') as file:
264
+ file.write(original_caption)
265
+
266
+ print(f"destination_folder {destination_folder}")
267
+ return destination_folder
268
+
269
+
270
+ def run_captioning(images, concept_sentence, *captions):
271
+ print(f"run_captioning")
272
+ print(f"concept sentence {concept_sentence}")
273
+ print(f"captions {captions}")
274
+ #Load internally to not consume resources for training
275
+ device = "cuda" if torch.cuda.is_available() else "cpu"
276
+ print(f"device={device}")
277
+ torch_dtype = torch.float16
278
+ model = AutoModelForCausalLM.from_pretrained(
279
+ "multimodalart/Florence-2-large-no-flash-attn", torch_dtype=torch_dtype, trust_remote_code=True
280
+ ).to(device)
281
+ processor = AutoProcessor.from_pretrained("multimodalart/Florence-2-large-no-flash-attn", trust_remote_code=True)
282
+
283
+ captions = list(captions)
284
+ for i, image_path in enumerate(images):
285
+ print(captions[i])
286
+ if isinstance(image_path, str): # If image is a file path
287
+ image = Image.open(image_path).convert("RGB")
288
+
289
+ prompt = "<DETAILED_CAPTION>"
290
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
291
+ print(f"inputs {inputs}")
292
+
293
+ generated_ids = model.generate(
294
+ input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
295
+ )
296
+ print(f"generated_ids {generated_ids}")
297
+
298
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
299
+ print(f"generated_text: {generated_text}")
300
+ parsed_answer = processor.post_process_generation(
301
+ generated_text, task=prompt, image_size=(image.width, image.height)
302
+ )
303
+ print(f"parsed_answer = {parsed_answer}")
304
+ caption_text = parsed_answer["<DETAILED_CAPTION>"].replace("The image shows ", "")
305
+ print(f"caption_text = {caption_text}, concept_sentence={concept_sentence}")
306
+ if concept_sentence:
307
+ caption_text = f"{concept_sentence} {caption_text}"
308
+ captions[i] = caption_text
309
+
310
+ yield captions
311
+ model.to("cpu")
312
+ del model
313
+ del processor
314
+ if torch.cuda.is_available():
315
+ torch.cuda.empty_cache()
316
+
317
+ def recursive_update(d, u):
318
+ for k, v in u.items():
319
+ if isinstance(v, dict) and v:
320
+ d[k] = recursive_update(d.get(k, {}), v)
321
+ else:
322
+ d[k] = v
323
+ return d
324
+
325
+ def download(base_model):
326
+ model = models[base_model]
327
+ model_file = model["file"]
328
+ repo = model["repo"]
329
+
330
+ # download unet
331
+ if base_model == "flux-dev" or base_model == "flux-schnell":
332
+ unet_folder = "models/unet"
333
+ else:
334
+ unet_folder = f"models/unet/{repo}"
335
+ unet_path = os.path.join(unet_folder, model_file)
336
+ if not os.path.exists(unet_path):
337
+ os.makedirs(unet_folder, exist_ok=True)
338
+ gr.Info(f"Downloading base model: {base_model}. Please wait. (You can check the terminal for the download progress)", duration=None)
339
+ print(f"download {base_model}")
340
+ hf_hub_download(repo_id=repo, local_dir=unet_folder, filename=model_file)
341
+
342
+ # download vae
343
+ vae_folder = "models/vae"
344
+ vae_path = os.path.join(vae_folder, "ae.sft")
345
+ if not os.path.exists(vae_path):
346
+ os.makedirs(vae_folder, exist_ok=True)
347
+ gr.Info(f"Downloading vae")
348
+ print(f"downloading ae.sft...")
349
+ hf_hub_download(repo_id="cocktailpeanut/xulf-dev", local_dir=vae_folder, filename="ae.sft")
350
+
351
+ # download clip
352
+ clip_folder = "models/clip"
353
+ clip_l_path = os.path.join(clip_folder, "clip_l.safetensors")
354
+ if not os.path.exists(clip_l_path):
355
+ os.makedirs(clip_folder, exist_ok=True)
356
+ gr.Info(f"Downloading clip...")
357
+ print(f"download clip_l.safetensors")
358
+ hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", local_dir=clip_folder, filename="clip_l.safetensors")
359
+
360
+ # download t5xxl
361
+ t5xxl_path = os.path.join(clip_folder, "t5xxl_fp16.safetensors")
362
+ if not os.path.exists(t5xxl_path):
363
+ print(f"download t5xxl_fp16.safetensors")
364
+ gr.Info(f"Downloading t5xxl...")
365
+ hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", local_dir=clip_folder, filename="t5xxl_fp16.safetensors")
366
+
367
+
368
+ def resolve_path(p):
369
+ current_dir = os.path.dirname(os.path.abspath(__file__))
370
+ norm_path = os.path.normpath(os.path.join(current_dir, p))
371
+ return f"\"{norm_path}\""
372
+ def resolve_path_without_quotes(p):
373
+ current_dir = os.path.dirname(os.path.abspath(__file__))
374
+ norm_path = os.path.normpath(os.path.join(current_dir, p))
375
+ return norm_path
376
+
377
+ def gen_sh(
378
+ base_model,
379
+ output_name,
380
+ resolution,
381
+ seed,
382
+ workers,
383
+ learning_rate,
384
+ network_dim,
385
+ max_train_epochs,
386
+ save_every_n_epochs,
387
+ timestep_sampling,
388
+ guidance_scale,
389
+ vram,
390
+ sample_prompts,
391
+ sample_every_n_steps,
392
+ *advanced_components
393
+ ):
394
+
395
+ print(f"gen_sh: network_dim:{network_dim}, max_train_epochs={max_train_epochs}, save_every_n_epochs={save_every_n_epochs}, timestep_sampling={timestep_sampling}, guidance_scale={guidance_scale}, vram={vram}, sample_prompts={sample_prompts}, sample_every_n_steps={sample_every_n_steps}")
396
+
397
+ output_dir = resolve_path(f"outputs/{output_name}")
398
+ sample_prompts_path = resolve_path(f"outputs/{output_name}/sample_prompts.txt")
399
+
400
+ line_break = "\\"
401
+ file_type = "sh"
402
+ if sys.platform == "win32":
403
+ line_break = "^"
404
+ file_type = "bat"
405
+
406
+ ############# Sample args ########################
407
+ sample = ""
408
+ if len(sample_prompts) > 0 and sample_every_n_steps > 0:
409
+ sample = f"""--sample_prompts={sample_prompts_path} --sample_every_n_steps="{sample_every_n_steps}" {line_break}"""
410
+
411
+
412
+ ############# Optimizer args ########################
413
+ # if vram == "8G":
414
+ # optimizer = f"""--optimizer_type adafactor {line_break}
415
+ # --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break}
416
+ # --split_mode {line_break}
417
+ # --network_args "train_blocks=single" {line_break}
418
+ # --lr_scheduler constant_with_warmup {line_break}
419
+ # --max_grad_norm 0.0 {line_break}"""
420
+ if vram == "16G":
421
+ # 16G VRAM
422
+ optimizer = f"""--optimizer_type adafactor {line_break}
423
+ --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break}
424
+ --lr_scheduler constant_with_warmup {line_break}
425
+ --max_grad_norm 0.0 {line_break}"""
426
+ elif vram == "12G":
427
+ # 12G VRAM
428
+ optimizer = f"""--optimizer_type adafactor {line_break}
429
+ --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break}
430
+ --split_mode {line_break}
431
+ --network_args "train_blocks=single" {line_break}
432
+ --lr_scheduler constant_with_warmup {line_break}
433
+ --max_grad_norm 0.0 {line_break}"""
434
+ else:
435
+ # 20G+ VRAM
436
+ optimizer = f"--optimizer_type adamw8bit {line_break}"
437
+
438
+
439
+ #######################################################
440
+ model_config = models[base_model]
441
+ model_file = model_config["file"]
442
+ repo = model_config["repo"]
443
+ if base_model == "flux-dev" or base_model == "flux-schnell":
444
+ model_folder = "models/unet"
445
+ else:
446
+ model_folder = f"models/unet/{repo}"
447
+ model_path = os.path.join(model_folder, model_file)
448
+ pretrained_model_path = resolve_path(model_path)
449
+
450
+ clip_path = resolve_path("models/clip/clip_l.safetensors")
451
+ t5_path = resolve_path("models/clip/t5xxl_fp16.safetensors")
452
+ ae_path = resolve_path("models/vae/ae.sft")
453
+ sh = f"""accelerate launch {line_break}
454
+ --mixed_precision bf16 {line_break}
455
+ --num_cpu_threads_per_process 1 {line_break}
456
+ sd-scripts/flux_train_network.py {line_break}
457
+ --pretrained_model_name_or_path {pretrained_model_path} {line_break}
458
+ --clip_l {clip_path} {line_break}
459
+ --t5xxl {t5_path} {line_break}
460
+ --ae {ae_path} {line_break}
461
+ --cache_latents_to_disk {line_break}
462
+ --save_model_as safetensors {line_break}
463
+ --sdpa --persistent_data_loader_workers {line_break}
464
+ --max_data_loader_n_workers {workers} {line_break}
465
+ --seed {seed} {line_break}
466
+ --gradient_checkpointing {line_break}
467
+ --mixed_precision bf16 {line_break}
468
+ --save_precision bf16 {line_break}
469
+ --network_module networks.lora_flux {line_break}
470
+ --network_dim {network_dim} {line_break}
471
+ {optimizer}{sample}
472
+ --learning_rate {learning_rate} {line_break}
473
+ --cache_text_encoder_outputs {line_break}
474
+ --cache_text_encoder_outputs_to_disk {line_break}
475
+ --fp8_base {line_break}
476
+ --highvram {line_break}
477
+ --max_train_epochs {max_train_epochs} {line_break}
478
+ --save_every_n_epochs {save_every_n_epochs} {line_break}
479
+ --dataset_config {resolve_path(f"outputs/{output_name}/dataset.toml")} {line_break}
480
+ --output_dir {output_dir} {line_break}
481
+ --output_name {output_name} {line_break}
482
+ --timestep_sampling {timestep_sampling} {line_break}
483
+ --discrete_flow_shift 3.1582 {line_break}
484
+ --model_prediction_type raw {line_break}
485
+ --guidance_scale {guidance_scale} {line_break}
486
+ --loss_type l2 {line_break}"""
487
+
488
+
489
+
490
+ ############# Advanced args ########################
491
+ global advanced_component_ids
492
+ global original_advanced_component_values
493
+
494
+ # check dirty
495
+ print(f"original_advanced_component_values = {original_advanced_component_values}")
496
+ advanced_flags = []
497
+ for i, current_value in enumerate(advanced_components):
498
+ # print(f"compare {advanced_component_ids[i]}: old={original_advanced_component_values[i]}, new={current_value}")
499
+ if original_advanced_component_values[i] != current_value:
500
+ # dirty
501
+ if current_value == True:
502
+ # Boolean
503
+ advanced_flags.append(advanced_component_ids[i])
504
+ else:
505
+ # string
506
+ advanced_flags.append(f"{advanced_component_ids[i]} {current_value}")
507
+
508
+ if len(advanced_flags) > 0:
509
+ advanced_flags_str = f" {line_break}\n ".join(advanced_flags)
510
+ sh = sh + "\n " + advanced_flags_str
511
+
512
+ return sh
513
+
514
+ def gen_toml(
515
+ dataset_folder,
516
+ resolution,
517
+ class_tokens,
518
+ num_repeats
519
+ ):
520
+ toml = f"""[general]
521
+ shuffle_caption = false
522
+ caption_extension = '.txt'
523
+ keep_tokens = 1
524
+
525
+ [[datasets]]
526
+ resolution = {resolution}
527
+ batch_size = 1
528
+ keep_tokens = 1
529
+
530
+ [[datasets.subsets]]
531
+ image_dir = '{resolve_path_without_quotes(dataset_folder)}'
532
+ class_tokens = '{class_tokens}'
533
+ num_repeats = {num_repeats}"""
534
+ return toml
535
+
536
+ def update_total_steps(max_train_epochs, num_repeats, images):
537
+ try:
538
+ num_images = len(images)
539
+ total_steps = max_train_epochs * num_images * num_repeats
540
+ print(f"max_train_epochs={max_train_epochs} num_images={num_images}, num_repeats={num_repeats}, total_steps={total_steps}")
541
+ return gr.update(value = total_steps)
542
+ except:
543
+ print("")
544
+
545
+ def set_repo(lora_rows):
546
+ selected_name = os.path.basename(lora_rows)
547
+ return gr.update(value=selected_name)
548
+
549
+ def get_loras():
550
+ try:
551
+ outputs_path = resolve_path_without_quotes(f"outputs")
552
+ files = os.listdir(outputs_path)
553
+ folders = [os.path.join(outputs_path, item) for item in files if os.path.isdir(os.path.join(outputs_path, item)) and item != "sample"]
554
+ folders.sort(key=lambda file: os.path.getctime(file), reverse=True)
555
+ return folders
556
+ except Exception as e:
557
+ return []
558
+
559
+ def get_samples(lora_name):
560
+ output_name = slugify(lora_name)
561
+ try:
562
+ samples_path = resolve_path_without_quotes(f"outputs/{output_name}/sample")
563
+ files = [os.path.join(samples_path, file) for file in os.listdir(samples_path)]
564
+ files.sort(key=lambda file: os.path.getctime(file), reverse=True)
565
+ return files
566
+ except:
567
+ return []
568
+
569
+ def start_training(
570
+ base_model,
571
+ lora_name,
572
+ train_script,
573
+ train_config,
574
+ sample_prompts,
575
+ ):
576
+ # write custom script and toml
577
+ if not os.path.exists("models"):
578
+ os.makedirs("models", exist_ok=True)
579
+ if not os.path.exists("outputs"):
580
+ os.makedirs("outputs", exist_ok=True)
581
+ output_name = slugify(lora_name)
582
+ output_dir = resolve_path_without_quotes(f"outputs/{output_name}")
583
+ if not os.path.exists(output_dir):
584
+ os.makedirs(output_dir, exist_ok=True)
585
+
586
+ download(base_model)
587
+
588
+ file_type = "sh"
589
+ if sys.platform == "win32":
590
+ file_type = "bat"
591
+
592
+ sh_filename = f"train.{file_type}"
593
+ sh_filepath = resolve_path_without_quotes(f"outputs/{output_name}/{sh_filename}")
594
+ with open(sh_filepath, 'w', encoding="utf-8") as file:
595
+ file.write(train_script)
596
+ gr.Info(f"Generated train script at {sh_filename}")
597
+
598
+
599
+ dataset_path = resolve_path_without_quotes(f"outputs/{output_name}/dataset.toml")
600
+ with open(dataset_path, 'w', encoding="utf-8") as file:
601
+ file.write(train_config)
602
+ gr.Info(f"Generated dataset.toml")
603
+
604
+ sample_prompts_path = resolve_path_without_quotes(f"outputs/{output_name}/sample_prompts.txt")
605
+ with open(sample_prompts_path, 'w', encoding='utf-8') as file:
606
+ file.write(sample_prompts)
607
+ gr.Info(f"Generated sample_prompts.txt")
608
+
609
+ # Train
610
+ if sys.platform == "win32":
611
+ command = sh_filepath
612
+ else:
613
+ command = f"bash \"{sh_filepath}\""
614
+
615
+ # Use Popen to run the command and capture output in real-time
616
+ env = os.environ.copy()
617
+ env['PYTHONIOENCODING'] = 'utf-8'
618
+ env['LOG_LEVEL'] = 'DEBUG'
619
+ runner = LogsViewRunner()
620
+ cwd = os.path.dirname(os.path.abspath(__file__))
621
+ gr.Info(f"Started training")
622
+ yield from runner.run_command([command], cwd=cwd)
623
+ yield runner.log(f"Runner: {runner}")
624
+
625
+ # Generate Readme
626
+ config = toml.loads(train_config)
627
+ concept_sentence = config['datasets'][0]['subsets'][0]['class_tokens']
628
+ print(f"concept_sentence={concept_sentence}")
629
+ print(f"lora_name {lora_name}, concept_sentence={concept_sentence}, output_name={output_name}")
630
+ sample_prompts_path = resolve_path_without_quotes(f"outputs/{output_name}/sample_prompts.txt")
631
+ with open(sample_prompts_path, "r", encoding="utf-8") as f:
632
+ lines = f.readlines()
633
+ sample_prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
634
+ md = readme(base_model, lora_name, concept_sentence, sample_prompts)
635
+ readme_path = resolve_path_without_quotes(f"outputs/{output_name}/README.md")
636
+ with open(readme_path, "w", encoding="utf-8") as f:
637
+ f.write(md)
638
+
639
+ gr.Info(f"Training Complete. Check the outputs folder for the LoRA files.", duration=None)
640
+
641
+
642
+ def update(
643
+ base_model,
644
+ lora_name,
645
+ resolution,
646
+ seed,
647
+ workers,
648
+ class_tokens,
649
+ learning_rate,
650
+ network_dim,
651
+ max_train_epochs,
652
+ save_every_n_epochs,
653
+ timestep_sampling,
654
+ guidance_scale,
655
+ vram,
656
+ num_repeats,
657
+ sample_prompts,
658
+ sample_every_n_steps,
659
+ *advanced_components,
660
+ ):
661
+ output_name = slugify(lora_name)
662
+ dataset_folder = str(f"datasets/{output_name}")
663
+ sh = gen_sh(
664
+ base_model,
665
+ output_name,
666
+ resolution,
667
+ seed,
668
+ workers,
669
+ learning_rate,
670
+ network_dim,
671
+ max_train_epochs,
672
+ save_every_n_epochs,
673
+ timestep_sampling,
674
+ guidance_scale,
675
+ vram,
676
+ sample_prompts,
677
+ sample_every_n_steps,
678
+ *advanced_components,
679
+ )
680
+ toml = gen_toml(
681
+ dataset_folder,
682
+ resolution,
683
+ class_tokens,
684
+ num_repeats
685
+ )
686
+ return gr.update(value=sh), gr.update(value=toml), dataset_folder
687
+
688
+ """
689
+ demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, hf_account])
690
+ """
691
+ def loaded():
692
+ global current_account
693
+ current_account = account_hf()
694
+ print(f"current_account={current_account}")
695
+ if current_account != None:
696
+ return gr.update(value=current_account["token"]), gr.update(visible=False), gr.update(visible=True), gr.update(value=current_account["account"], visible=True)
697
+ else:
698
+ return gr.update(value=""), gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False)
699
+
700
+ def update_sample(concept_sentence):
701
+ return gr.update(value=concept_sentence)
702
+
703
+ def refresh_publish_tab():
704
+ loras = get_loras()
705
+ return gr.Dropdown(label="Trained LoRAs", choices=loras)
706
+
707
+ def init_advanced():
708
+ # if basic_args
709
+ basic_args = {
710
+ 'pretrained_model_name_or_path',
711
+ 'clip_l',
712
+ 't5xxl',
713
+ 'ae',
714
+ 'cache_latents_to_disk',
715
+ 'save_model_as',
716
+ 'sdpa',
717
+ 'persistent_data_loader_workers',
718
+ 'max_data_loader_n_workers',
719
+ 'seed',
720
+ 'gradient_checkpointing',
721
+ 'mixed_precision',
722
+ 'save_precision',
723
+ 'network_module',
724
+ 'network_dim',
725
+ 'learning_rate',
726
+ 'cache_text_encoder_outputs',
727
+ 'cache_text_encoder_outputs_to_disk',
728
+ 'fp8_base',
729
+ 'highvram',
730
+ 'max_train_epochs',
731
+ 'save_every_n_epochs',
732
+ 'dataset_config',
733
+ 'output_dir',
734
+ 'output_name',
735
+ 'timestep_sampling',
736
+ 'discrete_flow_shift',
737
+ 'model_prediction_type',
738
+ 'guidance_scale',
739
+ 'loss_type',
740
+ 'optimizer_type',
741
+ 'optimizer_args',
742
+ 'lr_scheduler',
743
+ 'sample_prompts',
744
+ 'sample_every_n_steps',
745
+ 'max_grad_norm',
746
+ 'split_mode',
747
+ 'network_args'
748
+ }
749
+
750
+ # generate a UI config
751
+ # if not in basic_args, create a simple form
752
+ parser = train_network.setup_parser()
753
+ flux_train_utils.add_flux_train_arguments(parser)
754
+ args_info = {}
755
+ for action in parser._actions:
756
+ if action.dest != 'help': # Skip the default help argument
757
+ # if the dest is included in basic_args
758
+ args_info[action.dest] = {
759
+ "action": action.option_strings, # Option strings like '--use_8bit_adam'
760
+ "type": action.type, # Type of the argument
761
+ "help": action.help, # Help message
762
+ "default": action.default, # Default value, if any
763
+ "required": action.required # Whether the argument is required
764
+ }
765
+ temp = []
766
+ for key in args_info:
767
+ temp.append({ 'key': key, 'action': args_info[key] })
768
+ temp.sort(key=lambda x: x['key'])
769
+ advanced_component_ids = []
770
+ advanced_components = []
771
+ for item in temp:
772
+ key = item['key']
773
+ action = item['action']
774
+ if key in basic_args:
775
+ print("")
776
+ else:
777
+ action_type = str(action['type'])
778
+ component = None
779
+ with gr.Column(min_width=300):
780
+ if action_type == "None":
781
+ # radio
782
+ component = gr.Checkbox()
783
+ # elif action_type == "<class 'str'>":
784
+ # component = gr.Textbox()
785
+ # elif action_type == "<class 'int'>":
786
+ # component = gr.Number(precision=0)
787
+ # elif action_type == "<class 'float'>":
788
+ # component = gr.Number()
789
+ # elif "int_or_float" in action_type:
790
+ # component = gr.Number()
791
+ else:
792
+ component = gr.Textbox(value="")
793
+ if component != None:
794
+ component.interactive = True
795
+ component.elem_id = action['action'][0]
796
+ component.label = component.elem_id
797
+ component.elem_classes = ["advanced"]
798
+ if action['help'] != None:
799
+ component.info = action['help']
800
+ advanced_components.append(component)
801
+ advanced_component_ids.append(component.elem_id)
802
+ return advanced_components, advanced_component_ids
803
+
804
+
805
+ theme = gr.themes.Monochrome(
806
+ text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"),
807
+ font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"],
808
+ )
809
+ css = """
810
+ @keyframes rotate {
811
+ 0% {
812
+ transform: rotate(0deg);
813
+ }
814
+ 100% {
815
+ transform: rotate(360deg);
816
+ }
817
+ }
818
+ #advanced_options .advanced:nth-child(even) { background: rgba(0,0,100,0.04) !important; }
819
+ h1{font-family: georgia; font-style: italic; font-weight: bold; font-size: 30px; letter-spacing: -1px;}
820
+ h3{margin-top: 0}
821
+ .tabitem{border: 0px}
822
+ .group_padding{}
823
+ nav{position: fixed; top: 0; left: 0; right: 0; z-index: 1000; text-align: center; padding: 10px; box-sizing: border-box; display: flex; align-items: center; backdrop-filter: blur(10px); }
824
+ nav button { background: none; color: firebrick; font-weight: bold; border: 2px solid firebrick; padding: 5px 10px; border-radius: 5px; font-size: 14px; }
825
+ nav img { height: 40px; width: 40px; border-radius: 40px; }
826
+ nav img.rotate { animation: rotate 2s linear infinite; }
827
+ .flexible { flex-grow: 1; }
828
+ .tast-details { margin: 10px 0 !important; }
829
+ .toast-wrap { bottom: var(--size-4) !important; top: auto !important; border: none !important; backdrop-filter: blur(10px); }
830
+ .toast-title, .toast-text, .toast-icon, .toast-close { color: black !important; font-size: 14px; }
831
+ .toast-body { border: none !important; }
832
+ #terminal { box-shadow: none !important; margin-bottom: 25px; background: rgba(0,0,0,0.03); }
833
+ #terminal .generating { border: none !important; }
834
+ #terminal label { position: absolute !important; }
835
+ .tabs { margin-top: 50px; }
836
+ .hidden { display: none !important; }
837
+ .codemirror-wrapper .cm-line { font-size: 12px !important; }
838
+ label { font-weight: bold !important; }
839
+ #start_training.clicked { background: silver; color: black; }
840
+ """
841
+
842
+ js = """
843
+ function() {
844
+ let autoscroll = document.querySelector("#autoscroll")
845
+ if (window.iidxx) {
846
+ window.clearInterval(window.iidxx);
847
+ }
848
+ window.iidxx = window.setInterval(function() {
849
+ let text=document.querySelector(".codemirror-wrapper .cm-line").innerText.trim()
850
+ let img = document.querySelector("#logo")
851
+ if (text.length > 0) {
852
+ autoscroll.classList.remove("hidden")
853
+ if (autoscroll.classList.contains("on")) {
854
+ autoscroll.textContent = "Autoscroll ON"
855
+ window.scrollTo(0, document.body.scrollHeight, { behavior: "smooth" });
856
+ img.classList.add("rotate")
857
+ } else {
858
+ autoscroll.textContent = "Autoscroll OFF"
859
+ img.classList.remove("rotate")
860
+ }
861
+ }
862
+ }, 500);
863
+ console.log("autoscroll", autoscroll)
864
+ autoscroll.addEventListener("click", (e) => {
865
+ autoscroll.classList.toggle("on")
866
+ })
867
+ function debounce(fn, delay) {
868
+ let timeoutId;
869
+ return function(...args) {
870
+ clearTimeout(timeoutId);
871
+ timeoutId = setTimeout(() => fn(...args), delay);
872
+ };
873
+ }
874
+
875
+ function handleClick() {
876
+ console.log("refresh")
877
+ document.querySelector("#refresh").click();
878
+ }
879
+ const debouncedClick = debounce(handleClick, 1000);
880
+ document.addEventListener("input", debouncedClick);
881
+
882
+ document.querySelector("#start_training").addEventListener("click", (e) => {
883
+ e.target.classList.add("clicked")
884
+ e.target.innerHTML = "Training..."
885
+ })
886
+
887
+ }
888
+ """
889
+
890
+ current_account = account_hf()
891
+ print(f"current_account={current_account}")
892
+
893
+ with gr.Blocks(elem_id="app", theme=theme, css=css, fill_width=True) as demo:
894
+ with gr.Tabs() as tabs:
895
+ with gr.TabItem("Gym"):
896
+ output_components = []
897
+ with gr.Row():
898
+ gr.HTML("""<nav>
899
+ <img id='logo' src='/file=icon.png' width='80' height='80'>
900
+ <div class='flexible'></div>
901
+ <button id='autoscroll' class='on hidden'></button>
902
+ </nav>
903
+ """)
904
+ with gr.Row(elem_id='container'):
905
+ with gr.Column():
906
+ gr.Markdown(
907
+ """# Step 1. LoRA Info
908
+ <p style="margin-top:0">Configure your LoRA train settings.</p>
909
+ """, elem_classes="group_padding")
910
+ lora_name = gr.Textbox(
911
+ label="The name of your LoRA",
912
+ info="This has to be a unique name",
913
+ placeholder="e.g.: Persian Miniature Painting style, Cat Toy",
914
+ )
915
+ concept_sentence = gr.Textbox(
916
+ elem_id="--concept_sentence",
917
+ label="Trigger word/sentence",
918
+ info="Trigger word or sentence to be used",
919
+ placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
920
+ interactive=True,
921
+ )
922
+ model_names = list(models.keys())
923
+ print(f"model_names={model_names}")
924
+ base_model = gr.Dropdown(label="Base model (edit the models.yaml file to add more to this list)", choices=model_names, value=model_names[0])
925
+ vram = gr.Radio(["20G", "16G", "12G" ], value="20G", label="VRAM", interactive=True)
926
+ num_repeats = gr.Number(value=10, precision=0, label="Repeat trains per image", interactive=True)
927
+ max_train_epochs = gr.Number(label="Max Train Epochs", value=16, interactive=True)
928
+ total_steps = gr.Number(0, interactive=False, label="Expected training steps")
929
+ sample_prompts = gr.Textbox("", lines=5, label="Sample Image Prompts (Separate with new lines)", interactive=True)
930
+ sample_every_n_steps = gr.Number(0, precision=0, label="Sample Image Every N Steps", interactive=True)
931
+ resolution = gr.Number(value=512, precision=0, label="Resize dataset images")
932
+ with gr.Column():
933
+ gr.Markdown(
934
+ """# Step 2. Dataset
935
+ <p style="margin-top:0">Make sure the captions include the trigger word.</p>
936
+ """, elem_classes="group_padding")
937
+ with gr.Group():
938
+ images = gr.File(
939
+ file_types=["image", ".txt"],
940
+ label="Upload your images",
941
+ #info="If you want, you can also manually upload caption files that match the image names (example: img0.png => img0.txt)",
942
+ file_count="multiple",
943
+ interactive=True,
944
+ visible=True,
945
+ scale=1,
946
+ )
947
+ with gr.Group(visible=False) as captioning_area:
948
+ do_captioning = gr.Button("Add AI captions with Florence-2")
949
+ output_components.append(captioning_area)
950
+ #output_components = [captioning_area]
951
+ caption_list = []
952
+ for i in range(1, MAX_IMAGES + 1):
953
+ locals()[f"captioning_row_{i}"] = gr.Row(visible=False)
954
+ with locals()[f"captioning_row_{i}"]:
955
+ locals()[f"image_{i}"] = gr.Image(
956
+ type="filepath",
957
+ width=111,
958
+ height=111,
959
+ min_width=111,
960
+ interactive=False,
961
+ scale=2,
962
+ show_label=False,
963
+ show_share_button=False,
964
+ show_download_button=False,
965
+ )
966
+ locals()[f"caption_{i}"] = gr.Textbox(
967
+ label=f"Caption {i}", scale=15, interactive=True
968
+ )
969
+
970
+ output_components.append(locals()[f"captioning_row_{i}"])
971
+ output_components.append(locals()[f"image_{i}"])
972
+ output_components.append(locals()[f"caption_{i}"])
973
+ caption_list.append(locals()[f"caption_{i}"])
974
+ with gr.Column():
975
+ gr.Markdown(
976
+ """# Step 3. Train
977
+ <p style="margin-top:0">Press start to start training.</p>
978
+ """, elem_classes="group_padding")
979
+ refresh = gr.Button("Refresh", elem_id="refresh", visible=False)
980
+ start = gr.Button("Start training", visible=False, elem_id="start_training")
981
+ output_components.append(start)
982
+ train_script = gr.Textbox(label="Train script", max_lines=100, interactive=True)
983
+ train_config = gr.Textbox(label="Train config", max_lines=100, interactive=True)
984
+ with gr.Accordion("Advanced options", elem_id='advanced_options', open=False):
985
+ with gr.Row():
986
+ with gr.Column(min_width=300):
987
+ seed = gr.Number(label="--seed", info="Seed", value=42, interactive=True)
988
+ with gr.Column(min_width=300):
989
+ workers = gr.Number(label="--max_data_loader_n_workers", info="Number of Workers", value=2, interactive=True)
990
+ with gr.Column(min_width=300):
991
+ learning_rate = gr.Textbox(label="--learning_rate", info="Learning Rate", value="8e-4", interactive=True)
992
+ with gr.Column(min_width=300):
993
+ save_every_n_epochs = gr.Number(label="--save_every_n_epochs", info="Save every N epochs", value=4, interactive=True)
994
+ with gr.Column(min_width=300):
995
+ guidance_scale = gr.Number(label="--guidance_scale", info="Guidance Scale", value=1.0, interactive=True)
996
+ with gr.Column(min_width=300):
997
+ timestep_sampling = gr.Textbox(label="--timestep_sampling", info="Timestep Sampling", value="shift", interactive=True)
998
+ with gr.Column(min_width=300):
999
+ network_dim = gr.Number(label="--network_dim", info="LoRA Rank", value=4, minimum=4, maximum=128, step=4, interactive=True)
1000
+ advanced_components, advanced_component_ids = init_advanced()
1001
+ with gr.Row():
1002
+ terminal = LogsView(label="Train log", elem_id="terminal")
1003
+ with gr.Row():
1004
+ gallery = gr.Gallery(get_samples, inputs=[lora_name], label="Samples", every=10, columns=6)
1005
+
1006
+ with gr.TabItem("Publish") as publish_tab:
1007
+ hf_token = gr.Textbox(label="Huggingface Token")
1008
+ hf_login = gr.Button("Login")
1009
+ hf_logout = gr.Button("Logout")
1010
+ with gr.Row() as row:
1011
+ gr.Markdown("**LoRA**")
1012
+ gr.Markdown("**Upload**")
1013
+ loras = get_loras()
1014
+ with gr.Row():
1015
+ lora_rows = refresh_publish_tab()
1016
+ with gr.Column():
1017
+ with gr.Row():
1018
+ repo_owner = gr.Textbox(label="Account", interactive=False)
1019
+ repo_name = gr.Textbox(label="Repository Name")
1020
+ repo_visibility = gr.Textbox(label="Repository Visibility ('public' or 'private')", value="public")
1021
+ upload_button = gr.Button("Upload to HuggingFace")
1022
+ upload_button.click(
1023
+ fn=upload_hf,
1024
+ inputs=[
1025
+ base_model,
1026
+ lora_rows,
1027
+ repo_owner,
1028
+ repo_name,
1029
+ repo_visibility,
1030
+ hf_token,
1031
+ ]
1032
+ )
1033
+ hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner])
1034
+ hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner])
1035
+
1036
+
1037
+ publish_tab.select(refresh_publish_tab, outputs=lora_rows)
1038
+ lora_rows.select(fn=set_repo, inputs=[lora_rows], outputs=[repo_name])
1039
+
1040
+ dataset_folder = gr.State()
1041
+
1042
+ listeners = [
1043
+ base_model,
1044
+ lora_name,
1045
+ resolution,
1046
+ seed,
1047
+ workers,
1048
+ concept_sentence,
1049
+ learning_rate,
1050
+ network_dim,
1051
+ max_train_epochs,
1052
+ save_every_n_epochs,
1053
+ timestep_sampling,
1054
+ guidance_scale,
1055
+ vram,
1056
+ num_repeats,
1057
+ sample_prompts,
1058
+ sample_every_n_steps,
1059
+ *advanced_components
1060
+ ]
1061
+ advanced_component_ids = [x.elem_id for x in advanced_components]
1062
+ original_advanced_component_values = [comp.value for comp in advanced_components]
1063
+ images.upload(
1064
+ load_captioning,
1065
+ inputs=[images, concept_sentence],
1066
+ outputs=output_components
1067
+ )
1068
+ images.delete(
1069
+ load_captioning,
1070
+ inputs=[images, concept_sentence],
1071
+ outputs=output_components
1072
+ )
1073
+ images.clear(
1074
+ hide_captioning,
1075
+ outputs=[captioning_area, start]
1076
+ )
1077
+ max_train_epochs.change(
1078
+ fn=update_total_steps,
1079
+ inputs=[max_train_epochs, num_repeats, images],
1080
+ outputs=[total_steps]
1081
+ )
1082
+ num_repeats.change(
1083
+ fn=update_total_steps,
1084
+ inputs=[max_train_epochs, num_repeats, images],
1085
+ outputs=[total_steps]
1086
+ )
1087
+ images.upload(
1088
+ fn=update_total_steps,
1089
+ inputs=[max_train_epochs, num_repeats, images],
1090
+ outputs=[total_steps]
1091
+ )
1092
+ images.delete(
1093
+ fn=update_total_steps,
1094
+ inputs=[max_train_epochs, num_repeats, images],
1095
+ outputs=[total_steps]
1096
+ )
1097
+ images.clear(
1098
+ fn=update_total_steps,
1099
+ inputs=[max_train_epochs, num_repeats, images],
1100
+ outputs=[total_steps]
1101
+ )
1102
+ concept_sentence.change(fn=update_sample, inputs=[concept_sentence], outputs=sample_prompts)
1103
+ start.click(fn=create_dataset, inputs=[dataset_folder, resolution, images] + caption_list, outputs=dataset_folder).then(
1104
+ fn=start_training,
1105
+ inputs=[
1106
+ base_model,
1107
+ lora_name,
1108
+ train_script,
1109
+ train_config,
1110
+ sample_prompts,
1111
+ ],
1112
+ outputs=terminal,
1113
+ )
1114
+ do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list)
1115
+ demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, repo_owner])
1116
+ refresh.click(update, inputs=listeners, outputs=[train_script, train_config, dataset_folder])
1117
+ if __name__ == "__main__":
1118
+ cwd = os.path.dirname(os.path.abspath(__file__))
1119
+ demo.launch(debug=True, show_error=True, allowed_paths=[cwd])
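The app.py hunk above wires the UI together: image upload/delete/clear events refresh the captioning area and recompute total steps, start.click(...).then(...) chains dataset creation into training, and the Publish tab handles Hugging Face login and upload. Below is a minimal, hedged sketch of the same click-then chaining pattern; the component and function names are illustrative placeholders, not the ones defined in app.py.

# Minimal sketch of Gradio's click -> then chaining as used above.
# Names here are illustrative placeholders, not taken from app.py.
import gradio as gr

def prepare(text):
    return f"prepared: {text}"

def run(prepared):
    return f"trained with {prepared}"

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    mid = gr.Textbox(label="Prepared")
    out = gr.Textbox(label="Result")
    btn = gr.Button("Start")
    # the first handler writes to `mid`; the chained .then() consumes it
    btn.click(fn=prepare, inputs=inp, outputs=mid).then(fn=run, inputs=mid, outputs=out)

if __name__ == "__main__":
    demo.launch()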
datasets/1 ADDED
File without changes
docker-compose.yml ADDED
@@ -0,0 +1,28 @@
1
+ services:
2
+
3
+ fluxgym:
4
+ build:
5
+ context: .
6
+ # change the dockerfile to Dockerfile.cuda12.4 if you are running CUDA 12.4 drivers; otherwise leave as is
7
+ dockerfile: Dockerfile
8
+ image: fluxgym
9
+ container_name: fluxgym
10
+ ports:
11
+ - 7860:7860
12
+ environment:
13
+ - PUID=${PUID:-1000}
14
+ - PGID=${PGID:-1000}
15
+ volumes:
16
+ - /etc/localtime:/etc/localtime:ro
17
+ - /etc/timezone:/etc/timezone:ro
18
+ - ./:/app/fluxgym
19
+ stop_signal: SIGKILL
20
+ tty: true
21
+ deploy:
22
+ resources:
23
+ reservations:
24
+ devices:
25
+ - driver: nvidia
26
+ count: all
27
+ capabilities: [gpu]
28
+ restart: unless-stopped
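The compose service above exposes the UI on port 7860, bind-mounts the repository into /app/fluxgym, and reserves all NVIDIA GPUs through the deploy.resources.reservations.devices block. A hedged sanity check, assuming PyTorch is installed inside the image, to confirm that reservation is visible from within the running container:

# Hedged sanity check: run inside the fluxgym container after `docker compose up`
# to verify that the GPU reservation in docker-compose.yml is visible to PyTorch.
# Assumes PyTorch is installed in the image.
import torch

if torch.cuda.is_available():
    print(f"CUDA available, {torch.cuda.device_count()} device(s):")
    for i in range(torch.cuda.device_count()):
        print(f"  {i}: {torch.cuda.get_device_name(i)}")
else:
    print("CUDA not available - check the nvidia device reservation in docker-compose.yml")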
fine_tune.py ADDED
@@ -0,0 +1,560 @@
1
+ # training with captions
2
+ # XXX dropped option: hypernetwork training
3
+
4
+ import argparse
5
+ import math
6
+ import os
7
+ from multiprocessing import Value
8
+ import toml
9
+
10
+ from tqdm import tqdm
11
+
12
+ import torch
13
+ from library import deepspeed_utils, strategy_base
14
+ from library.device_utils import init_ipex, clean_memory_on_device
15
+
16
+ init_ipex()
17
+
18
+ from accelerate.utils import set_seed
19
+ from diffusers import DDPMScheduler
20
+
21
+ from library.utils import setup_logging, add_logging_arguments
22
+
23
+ setup_logging()
24
+ import logging
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ import library.train_util as train_util
29
+ import library.config_util as config_util
30
+ from library.config_util import (
31
+ ConfigSanitizer,
32
+ BlueprintGenerator,
33
+ )
34
+ import library.custom_train_functions as custom_train_functions
35
+ from library.custom_train_functions import (
36
+ apply_snr_weight,
37
+ get_weighted_text_embeddings,
38
+ prepare_scheduler_for_custom_training,
39
+ scale_v_prediction_loss_like_noise_prediction,
40
+ apply_debiased_estimation,
41
+ )
42
+ import library.strategy_sd as strategy_sd
43
+
44
+
45
+ def train(args):
46
+ train_util.verify_training_args(args)
47
+ train_util.prepare_dataset_args(args, True)
48
+ deepspeed_utils.prepare_deepspeed_args(args)
49
+ setup_logging(args, reset=True)
50
+
51
+ cache_latents = args.cache_latents
52
+
53
+ if args.seed is not None:
54
+ set_seed(args.seed) # initialize the random seed
55
+
56
+ tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
57
+ strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
58
+
59
+ # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
60
+ if cache_latents:
61
+ latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
62
+ False, args.cache_latents_to_disk, args.vae_batch_size, False
63
+ )
64
+ strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
65
+
66
+ # prepare the dataset
67
+ if args.dataset_class is None:
68
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
69
+ if args.dataset_config is not None:
70
+ logger.info(f"Load dataset config from {args.dataset_config}")
71
+ user_config = config_util.load_user_config(args.dataset_config)
72
+ ignored = ["train_data_dir", "in_json"]
73
+ if any(getattr(args, attr) is not None for attr in ignored):
74
+ logger.warning(
75
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
76
+ ", ".join(ignored)
77
+ )
78
+ )
79
+ else:
80
+ user_config = {
81
+ "datasets": [
82
+ {
83
+ "subsets": [
84
+ {
85
+ "image_dir": args.train_data_dir,
86
+ "metadata_file": args.in_json,
87
+ }
88
+ ]
89
+ }
90
+ ]
91
+ }
92
+
93
+ blueprint = blueprint_generator.generate(user_config, args)
94
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
95
+ else:
96
+ train_dataset_group = train_util.load_arbitrary_dataset(args)
97
+
98
+ current_epoch = Value("i", 0)
99
+ current_step = Value("i", 0)
100
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
101
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
102
+
103
+ if args.debug_dataset:
104
+ train_util.debug_dataset(train_dataset_group)
105
+ return
106
+ if len(train_dataset_group) == 0:
107
+ logger.error(
108
+ "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
109
+ )
110
+ return
111
+
112
+ if cache_latents:
113
+ assert (
114
+ train_dataset_group.is_latent_cacheable()
115
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
116
+
117
+ # prepare the accelerator
118
+ logger.info("prepare accelerator")
119
+ accelerator = train_util.prepare_accelerator(args)
120
+
121
+ # prepare dtypes for mixed precision and cast as appropriate
122
+ weight_dtype, save_dtype = train_util.prepare_dtype(args)
123
+ vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
124
+
125
+ # load the model
126
+ text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
127
+
128
+ # verify load/save model formats
129
+ if load_stable_diffusion_format:
130
+ src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
131
+ src_diffusers_model_path = None
132
+ else:
133
+ src_stable_diffusion_ckpt = None
134
+ src_diffusers_model_path = args.pretrained_model_name_or_path
135
+
136
+ if args.save_model_as is None:
137
+ save_stable_diffusion_format = load_stable_diffusion_format
138
+ use_safetensors = args.use_safetensors
139
+ else:
140
+ save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
141
+ use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
142
+
143
+ # function to set the xformers flag for the Diffusers model
144
+ def set_diffusers_xformers_flag(model, valid):
145
+ # model.set_use_memory_efficient_attention_xformers(valid) # likely to be removed in an upcoming release
146
+ # the pipeline normally searches its modules recursively for set_use_memory_efficient_attention_xformers
147
+ # when only the U-Net is used we have to replicate that recursive walk ourselves
148
+ # as of 0.10.2 the flag has to be set per module again
149
+
150
+ # Recursively walk through all the children.
151
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
152
+ # gets the message
153
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
154
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
155
+ module.set_use_memory_efficient_attention_xformers(valid)
156
+
157
+ for child in module.children():
158
+ fn_recursive_set_mem_eff(child)
159
+
160
+ fn_recursive_set_mem_eff(model)
161
+
162
+ # enable xformers or memory-efficient attention on the model
163
+ if args.diffusers_xformers:
164
+ accelerator.print("Use xformers by Diffusers")
165
+ set_diffusers_xformers_flag(unet, True)
166
+ else:
167
+ # the Windows build of xformers cannot train in float, so it must remain possible to disable xformers
168
+ accelerator.print("Disable Diffusers' xformers")
169
+ set_diffusers_xformers_flag(unet, False)
170
+ train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
171
+
172
+ # prepare for training
173
+ if cache_latents:
174
+ vae.to(accelerator.device, dtype=vae_dtype)
175
+ vae.requires_grad_(False)
176
+ vae.eval()
177
+
178
+ train_dataset_group.new_cache_latents(vae, accelerator)
179
+
180
+ vae.to("cpu")
181
+ clean_memory_on_device(accelerator.device)
182
+
183
+ accelerator.wait_for_everyone()
184
+
185
+ # prepare for training: put the models into the proper state
186
+ training_models = []
187
+ if args.gradient_checkpointing:
188
+ unet.enable_gradient_checkpointing()
189
+ training_models.append(unet)
190
+
191
+ if args.train_text_encoder:
192
+ accelerator.print("enable text encoder training")
193
+ if args.gradient_checkpointing:
194
+ text_encoder.gradient_checkpointing_enable()
195
+ training_models.append(text_encoder)
196
+ else:
197
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
198
+ text_encoder.requires_grad_(False) # the text encoder is not trained
199
+ if args.gradient_checkpointing:
200
+ text_encoder.gradient_checkpointing_enable()
201
+ text_encoder.train() # required for gradient_checkpointing
202
+ else:
203
+ text_encoder.eval()
204
+
205
+ text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
206
+ strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
207
+
208
+ if not cache_latents:
209
+ vae.requires_grad_(False)
210
+ vae.eval()
211
+ vae.to(accelerator.device, dtype=vae_dtype)
212
+
213
+ for m in training_models:
214
+ m.requires_grad_(True)
215
+
216
+ trainable_params = []
217
+ if args.learning_rate_te is None or not args.train_text_encoder:
218
+ for m in training_models:
219
+ trainable_params.extend(m.parameters())
220
+ else:
221
+ trainable_params = [
222
+ {"params": list(unet.parameters()), "lr": args.learning_rate},
223
+ {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
224
+ ]
225
+
226
+ # prepare the classes needed for training
227
+ accelerator.print("prepare optimizer, data loader etc.")
228
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
229
+
230
+ # prepare dataloader
231
+ # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
232
+ # some strategies can be None
233
+ train_dataset_group.set_current_strategies()
234
+
235
+ # number of DataLoader workers: note that persistent_workers cannot be used with 0 workers
236
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
237
+ train_dataloader = torch.utils.data.DataLoader(
238
+ train_dataset_group,
239
+ batch_size=1,
240
+ shuffle=True,
241
+ collate_fn=collator,
242
+ num_workers=n_workers,
243
+ persistent_workers=args.persistent_data_loader_workers,
244
+ )
245
+
246
+ # calculate the number of training steps
247
+ if args.max_train_epochs is not None:
248
+ args.max_train_steps = args.max_train_epochs * math.ceil(
249
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
250
+ )
251
+ accelerator.print(
252
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
253
+ )
254
+
255
+ # also pass the number of training steps to the dataset
256
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
257
+
258
+ # prepare the lr scheduler
259
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
260
+
261
+ # experimental feature: train fully in fp16 including gradients; cast the whole model to fp16
262
+ if args.full_fp16:
263
+ assert (
264
+ args.mixed_precision == "fp16"
265
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
266
+ accelerator.print("enable full fp16 training.")
267
+ unet.to(weight_dtype)
268
+ text_encoder.to(weight_dtype)
269
+
270
+ if args.deepspeed:
271
+ if args.train_text_encoder:
272
+ ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
273
+ else:
274
+ ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
275
+ ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
276
+ ds_model, optimizer, train_dataloader, lr_scheduler
277
+ )
278
+ training_models = [ds_model]
279
+ else:
280
+ # accelerator handles wrapping and device placement for us
281
+ if args.train_text_encoder:
282
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
283
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
284
+ )
285
+ else:
286
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
287
+
288
+ # experimental feature: train fully in fp16 including gradients; patch PyTorch to enable grad scaling in fp16
289
+ if args.full_fp16:
290
+ train_util.patch_accelerator_for_fp16_training(accelerator)
291
+
292
+ # resume training if specified
293
+ train_util.resume_from_local_or_hf_if_specified(accelerator, args)
294
+
295
+ # calculate the number of epochs
296
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
297
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
298
+ if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
299
+ args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
300
+
301
+ # start training
302
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
303
+ accelerator.print("running training / 学習開始")
304
+ accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
305
+ accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
306
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
307
+ accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
308
+ accelerator.print(
309
+ f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
310
+ )
311
+ accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
312
+ accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
313
+
314
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
315
+ global_step = 0
316
+
317
+ noise_scheduler = DDPMScheduler(
318
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
319
+ )
320
+ prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
321
+ if args.zero_terminal_snr:
322
+ custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
323
+
324
+ if accelerator.is_main_process:
325
+ init_kwargs = {}
326
+ if args.wandb_run_name:
327
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
328
+ if args.log_tracker_config is not None:
329
+ init_kwargs = toml.load(args.log_tracker_config)
330
+ accelerator.init_trackers(
331
+ "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
332
+ config=train_util.get_sanitized_config_or_none(args),
333
+ init_kwargs=init_kwargs,
334
+ )
335
+
336
+ # For --sample_at_first
337
+ train_util.sample_images(
338
+ accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
339
+ )
340
+
341
+ loss_recorder = train_util.LossRecorder()
342
+ for epoch in range(num_train_epochs):
343
+ accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
344
+ current_epoch.value = epoch + 1
345
+
346
+ for m in training_models:
347
+ m.train()
348
+
349
+ for step, batch in enumerate(train_dataloader):
350
+ current_step.value = global_step
351
+ with accelerator.accumulate(*training_models):
352
+ with torch.no_grad():
353
+ if "latents" in batch and batch["latents"] is not None:
354
+ latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
355
+ else:
356
+ # convert images to latents
357
+ latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype)
358
+ latents = latents * 0.18215
359
+ b_size = latents.shape[0]
360
+
361
+ with torch.set_grad_enabled(args.train_text_encoder):
362
+ # Get the text embedding for conditioning
363
+ if args.weighted_captions:
364
+ # TODO move to strategy_sd.py
365
+ encoder_hidden_states = get_weighted_text_embeddings(
366
+ tokenize_strategy.tokenizer,
367
+ text_encoder,
368
+ batch["captions"],
369
+ accelerator.device,
370
+ args.max_token_length // 75 if args.max_token_length else 1,
371
+ clip_skip=args.clip_skip,
372
+ )
373
+ else:
374
+ input_ids = batch["input_ids_list"][0].to(accelerator.device)
375
+ encoder_hidden_states = text_encoding_strategy.encode_tokens(
376
+ tokenize_strategy, [text_encoder], [input_ids]
377
+ )[0]
378
+ if args.full_fp16:
379
+ encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
380
+
381
+ # Sample noise, sample a random timestep for each image, and add noise to the latents,
382
+ # with noise offset and/or multires noise if specified
383
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
384
+ args, noise_scheduler, latents
385
+ )
386
+
387
+ # Predict the noise residual
388
+ with accelerator.autocast():
389
+ noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
390
+
391
+ if args.v_parameterization:
392
+ # v-parameterization training
393
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
394
+ else:
395
+ target = noise
396
+
397
+ if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
398
+ # do not mean over batch dimension for snr weight or scale v-pred loss
399
+ loss = train_util.conditional_loss(
400
+ noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
401
+ )
402
+ loss = loss.mean([1, 2, 3])
403
+
404
+ if args.min_snr_gamma:
405
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
406
+ if args.scale_v_pred_loss_like_noise_pred:
407
+ loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
408
+ if args.debiased_estimation_loss:
409
+ loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
410
+
411
+ loss = loss.mean() # mean over batch dimension
412
+ else:
413
+ loss = train_util.conditional_loss(
414
+ noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
415
+ )
416
+
417
+ accelerator.backward(loss)
418
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
419
+ params_to_clip = []
420
+ for m in training_models:
421
+ params_to_clip.extend(m.parameters())
422
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
423
+
424
+ optimizer.step()
425
+ lr_scheduler.step()
426
+ optimizer.zero_grad(set_to_none=True)
427
+
428
+ # Checks if the accelerator has performed an optimization step behind the scenes
429
+ if accelerator.sync_gradients:
430
+ progress_bar.update(1)
431
+ global_step += 1
432
+
433
+ train_util.sample_images(
434
+ accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
435
+ )
436
+
437
+ # save the model every specified number of steps
438
+ if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
439
+ accelerator.wait_for_everyone()
440
+ if accelerator.is_main_process:
441
+ src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
442
+ train_util.save_sd_model_on_epoch_end_or_stepwise(
443
+ args,
444
+ False,
445
+ accelerator,
446
+ src_path,
447
+ save_stable_diffusion_format,
448
+ use_safetensors,
449
+ save_dtype,
450
+ epoch,
451
+ num_train_epochs,
452
+ global_step,
453
+ accelerator.unwrap_model(text_encoder),
454
+ accelerator.unwrap_model(unet),
455
+ vae,
456
+ )
457
+
458
+ current_loss = loss.detach().item() # this is a mean, so batch size should not matter
459
+ if args.logging_dir is not None:
460
+ logs = {"loss": current_loss}
461
+ train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
462
+ accelerator.log(logs, step=global_step)
463
+
464
+ loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
465
+ avr_loss: float = loss_recorder.moving_average
466
+ logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
467
+ progress_bar.set_postfix(**logs)
468
+
469
+ if global_step >= args.max_train_steps:
470
+ break
471
+
472
+ if args.logging_dir is not None:
473
+ logs = {"loss/epoch": loss_recorder.moving_average}
474
+ accelerator.log(logs, step=epoch + 1)
475
+
476
+ accelerator.wait_for_everyone()
477
+
478
+ if args.save_every_n_epochs is not None:
479
+ if accelerator.is_main_process:
480
+ src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
481
+ train_util.save_sd_model_on_epoch_end_or_stepwise(
482
+ args,
483
+ True,
484
+ accelerator,
485
+ src_path,
486
+ save_stable_diffusion_format,
487
+ use_safetensors,
488
+ save_dtype,
489
+ epoch,
490
+ num_train_epochs,
491
+ global_step,
492
+ accelerator.unwrap_model(text_encoder),
493
+ accelerator.unwrap_model(unet),
494
+ vae,
495
+ )
496
+
497
+ train_util.sample_images(
498
+ accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
499
+ )
500
+
501
+ is_main_process = accelerator.is_main_process
502
+ if is_main_process:
503
+ unet = accelerator.unwrap_model(unet)
504
+ text_encoder = accelerator.unwrap_model(text_encoder)
505
+
506
+ accelerator.end_training()
507
+
508
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
509
+ train_util.save_state_on_train_end(args, accelerator)
510
+
511
+ del accelerator # delete this because memory is needed afterwards
512
+
513
+ if is_main_process:
514
+ src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
515
+ train_util.save_sd_model_on_train_end(
516
+ args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
517
+ )
518
+ logger.info("model saved.")
519
+
520
+
521
+ def setup_parser() -> argparse.ArgumentParser:
522
+ parser = argparse.ArgumentParser()
523
+
524
+ add_logging_arguments(parser)
525
+ train_util.add_sd_models_arguments(parser)
526
+ train_util.add_dataset_arguments(parser, False, True, True)
527
+ train_util.add_training_arguments(parser, False)
528
+ deepspeed_utils.add_deepspeed_arguments(parser)
529
+ train_util.add_sd_saving_arguments(parser)
530
+ train_util.add_optimizer_arguments(parser)
531
+ config_util.add_config_arguments(parser)
532
+ custom_train_functions.add_custom_train_arguments(parser)
533
+
534
+ parser.add_argument(
535
+ "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
536
+ )
537
+ parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
538
+ parser.add_argument(
539
+ "--learning_rate_te",
540
+ type=float,
541
+ default=None,
542
+ help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
543
+ )
544
+ parser.add_argument(
545
+ "--no_half_vae",
546
+ action="store_true",
547
+ help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
548
+ )
549
+
550
+ return parser
551
+
552
+
553
+ if __name__ == "__main__":
554
+ parser = setup_parser()
555
+
556
+ args = parser.parse_args()
557
+ train_util.verify_command_line_training_args(args)
558
+ args = train_util.read_config_from_file(args, parser)
559
+
560
+ train(args)
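The training loop in fine_tune.py above picks the regression target per step (raw noise for epsilon-prediction, noise_scheduler.get_velocity(...) for v-parameterization) and can re-weight the per-sample loss with min-SNR-gamma before averaging. Below is a hedged, standalone sketch of that weighting; the actual helper is apply_snr_weight in library.custom_train_functions and its exact signature may differ, but the formula follows the commonly used min(SNR, gamma)/SNR (epsilon) and min(SNR, gamma)/(SNR+1) (v-prediction) scheme.

# Hedged sketch of min-SNR-gamma loss weighting as applied in the loop above.
# The real implementation lives in library.custom_train_functions; names and
# conventions here are illustrative only.
import torch

def min_snr_weight(timesteps, alphas_cumprod, gamma, v_parameterization=False):
    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t) for a DDPM-style scheduler
    alpha_bar = alphas_cumprod[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)
    clipped = torch.minimum(snr, torch.full_like(snr, gamma))
    if v_parameterization:
        return clipped / (snr + 1.0)   # v-prediction target
    return clipped / snr               # epsilon-prediction target

# usage: per-sample loss of shape (B,) scaled before the final mean, e.g.
# loss = loss * min_snr_weight(timesteps, noise_scheduler.alphas_cumprod.to(loss.device), gamma=5.0)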
flags.png ADDED
flow.gif ADDED

Git LFS Details

  • SHA256: e502e5bcbfd25f5d7bad10e0b57a88c8f3b24006792d3a273d7bd964634a8fd9
  • Pointer size: 133 Bytes
  • Size of remote file: 11.3 MB
flux_extract_lora.py ADDED
@@ -0,0 +1,221 @@
1
+ # extract approximating LoRA by svd from two FLUX models
2
+ # The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
3
+ # Thanks to cloneofsimo!
4
+
5
+ import argparse
6
+ import json
7
+ import os
8
+ import time
9
+ import torch
10
+ from safetensors.torch import load_file, save_file
11
+ from safetensors import safe_open
12
+ from tqdm import tqdm
13
+ from .library import flux_utils, sai_model_spec
14
+ from .library.utils import MemoryEfficientSafeOpen
15
+ from .library.utils import setup_logging
16
+ from .networks import lora_flux
17
+
18
+ setup_logging()
19
+ import logging
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ from comfy.utils import ProgressBar
24
+ # CLAMP_QUANTILE = 0.99
25
+ # MIN_DIFF = 1e-1
26
+
27
+
28
+ def save_to_file(file_name, state_dict, metadata, dtype):
29
+ if dtype is not None:
30
+ for key in list(state_dict.keys()):
31
+ if type(state_dict[key]) == torch.Tensor:
32
+ state_dict[key] = state_dict[key].to(dtype)
33
+
34
+ save_file(state_dict, file_name, metadata=metadata)
35
+
36
+
37
+ def svd(
38
+ model_org=None,
39
+ model_tuned=None,
40
+ save_to=None,
41
+ dim=4,
42
+ device=None,
43
+ store_device='cpu',
44
+ save_precision=None,
45
+ clamp_quantile=0.99,
46
+ min_diff=0.01,
47
+ no_metadata=False,
48
+ mem_eff_safe_open=False,
49
+ ):
50
+ def str_to_dtype(p):
51
+ if p == "float":
52
+ return torch.float
53
+ if p == "fp16":
54
+ return torch.float16
55
+ if p == "bf16":
56
+ return torch.bfloat16
57
+ return None
58
+
59
+ calc_dtype = torch.float
60
+ save_dtype = str_to_dtype(save_precision)
61
+
62
+ # open models
63
+ lora_weights = {}
64
+ if not mem_eff_safe_open:
65
+ # use original safetensors.safe_open
66
+ open_fn = lambda fn: safe_open(fn, framework="pt")
67
+ else:
68
+ logger.info("Using memory efficient safe_open")
69
+ open_fn = lambda fn: MemoryEfficientSafeOpen(fn)
70
+
71
+ with open_fn(model_org) as fo:
72
+ # filter keys
73
+ keys = []
74
+ for key in fo.keys():
75
+ if not ("single_block" in key or "double_block" in key):
76
+ continue
77
+ if ".bias" in key:
78
+ continue
79
+ if "norm" in key:
80
+ continue
81
+ keys.append(key)
82
+ comfy_pbar = ProgressBar(len(keys))
83
+ with open_fn(model_tuned) as ft:
84
+ for key in tqdm(keys):
85
+ # get tensors and calculate difference
86
+ value_o = fo.get_tensor(key)
87
+ value_t = ft.get_tensor(key)
88
+ mat = value_t.to(calc_dtype) - value_o.to(calc_dtype)
89
+ del value_o, value_t
90
+
91
+ # extract LoRA weights
92
+ if device:
93
+ mat = mat.to(device)
94
+ out_dim, in_dim = mat.size()[0:2]
95
+ rank = min(dim, in_dim, out_dim) # LoRA rank cannot exceed the original dim
96
+
97
+ mat = mat.squeeze()
98
+
99
+ U, S, Vh = torch.linalg.svd(mat)
100
+
101
+ U = U[:, :rank]
102
+ S = S[:rank]
103
+ U = U @ torch.diag(S)
104
+
105
+ Vh = Vh[:rank, :]
106
+
107
+ dist = torch.cat([U.flatten(), Vh.flatten()])
108
+ hi_val = torch.quantile(dist, clamp_quantile)
109
+ low_val = -hi_val
110
+
111
+ U = U.clamp(low_val, hi_val)
112
+ Vh = Vh.clamp(low_val, hi_val)
113
+
114
+ U = U.to(store_device, dtype=save_dtype).contiguous()
115
+ Vh = Vh.to(store_device, dtype=save_dtype).contiguous()
116
+
117
+ print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}")
118
+ comfy_pbar.update(1)
119
+ lora_weights[key] = (U, Vh)
120
+ del mat, U, S, Vh
121
+
122
+ # make state dict for LoRA
123
+ lora_sd = {}
124
+ for key, (up_weight, down_weight) in lora_weights.items():
125
+ lora_name = key.replace(".weight", "").replace(".", "_")
126
+ lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + lora_name
127
+ lora_sd[lora_name + ".lora_up.weight"] = up_weight
128
+ lora_sd[lora_name + ".lora_down.weight"] = down_weight
129
+ lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0]) # same as rank
130
+
131
+ # minimum metadata
132
+ net_kwargs = {}
133
+ metadata = {
134
+ "ss_v2": str(False),
135
+ "ss_base_model_version": flux_utils.MODEL_VERSION_FLUX_V1,
136
+ "ss_network_module": "networks.lora_flux",
137
+ "ss_network_dim": str(dim),
138
+ "ss_network_alpha": str(float(dim)),
139
+ "ss_network_args": json.dumps(net_kwargs),
140
+ }
141
+
142
+ if not no_metadata:
143
+ title = os.path.splitext(os.path.basename(save_to))[0]
144
+ sai_metadata = sai_model_spec.build_metadata(lora_sd, False, False, False, True, False, time.time(), title, flux="dev")
145
+ metadata.update(sai_metadata)
146
+
147
+ save_to_file(save_to, lora_sd, metadata, save_dtype)
148
+
149
+ logger.info(f"LoRA weights saved to {save_to}")
150
+ return save_to
151
+
152
+
153
+ def setup_parser() -> argparse.ArgumentParser:
154
+ parser = argparse.ArgumentParser()
155
+ parser.add_argument(
156
+ "--save_precision",
157
+ type=str,
158
+ default=None,
159
+ choices=[None, "float", "fp16", "bf16"],
160
+ help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat",
161
+ )
162
+ parser.add_argument(
163
+ "--model_org",
164
+ type=str,
165
+ default=None,
166
+ required=True,
167
+ help="Original model: safetensors file / 元モデル、safetensors",
168
+ )
169
+ parser.add_argument(
170
+ "--model_tuned",
171
+ type=str,
172
+ default=None,
173
+ required=True,
174
+ help="Tuned model, LoRA is difference of `original to tuned`: safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
175
+ )
176
+ parser.add_argument(
177
+ "--mem_eff_safe_open",
178
+ action="store_true",
179
+ help="use memory efficient safe_open. This is an experimental feature, use only when memory is not enough."
180
+ " / メモリ効率の良いsafe_openを使用する。実装は実験的なものなので、メモリが足りない場合のみ使用してください。",
181
+ )
182
+ parser.add_argument(
183
+ "--save_to",
184
+ type=str,
185
+ default=None,
186
+ required=True,
187
+ help="destination file name: safetensors file / 保存先のファイル名、safetensors",
188
+ )
189
+ parser.add_argument(
190
+ "--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)"
191
+ )
192
+ parser.add_argument(
193
+ "--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
194
+ )
195
+ parser.add_argument(
196
+ "--clamp_quantile",
197
+ type=float,
198
+ default=0.99,
199
+ help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
200
+ )
201
+ # parser.add_argument(
202
+ # "--min_diff",
203
+ # type=float,
204
+ # default=0.01,
205
+ # help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
206
+ # + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
207
+ # )
208
+ parser.add_argument(
209
+ "--no_metadata",
210
+ action="store_true",
211
+ help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
212
+ + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
213
+ )
214
+ return parser
215
+
216
+
217
+ if __name__ == "__main__":
218
+ parser = setup_parser()
219
+
220
+ args = parser.parse_args()
221
+ svd(**vars(args))
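At its core, svd() above factors the per-layer weight difference between the tuned and original FLUX models into a rank-limited pair of LoRA matrices via truncated SVD, folding the singular values into the up weight and clamping outliers at a quantile. A hedged, self-contained sketch of that factorization for a single 2D weight is shown below; it mirrors the defaults above (rank 4, 0.99 quantile) but is an illustration, not the script itself.

# Hedged sketch of the per-weight LoRA extraction performed in svd() above:
# factor the weight delta into (up @ down) of rank r via truncated SVD.
import torch

def extract_lora_pair(w_tuned, w_org, rank=4, clamp_quantile=0.99):
    mat = (w_tuned.float() - w_org.float()).squeeze()   # weight delta
    out_dim, in_dim = mat.shape
    r = min(rank, in_dim, out_dim)                       # rank cannot exceed the matrix dims

    U, S, Vh = torch.linalg.svd(mat)
    up = U[:, :r] @ torch.diag(S[:r])                    # fold singular values into the up weight
    down = Vh[:r, :]

    # clamp outliers symmetrically at the given quantile, as the script does
    hi = torch.quantile(torch.cat([up.flatten(), down.flatten()]), clamp_quantile)
    return up.clamp(-hi, hi).contiguous(), down.clamp(-hi, hi).contiguous()

# usage with hypothetical tensors:
# up, down = extract_lora_pair(torch.randn(3072, 3072), torch.randn(3072, 3072), rank=8)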
flux_train_comfy.py ADDED
@@ -0,0 +1,806 @@
1
+ # training with captions
2
+
3
+ # Swap blocks between CPU and GPU:
4
+ # This implementation is inspired by and based on the work of 2kpr.
5
+ # Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading.
6
+ # The original idea has been adapted and extended to fit the current project's needs.
7
+
8
+ # Key features:
9
+ # - CPU offloading during forward and backward passes
10
+ # - Use of fused optimizer and grad_hook for efficient gradient processing
11
+ # - Per-block fused optimizer instances
12
+
13
+ import argparse
14
+ import copy
15
+ import math
16
+ import os
17
+ from multiprocessing import Value
18
+ import toml
19
+
20
+ from tqdm import tqdm
21
+
22
+ import torch
23
+ from .library.device_utils import init_ipex, clean_memory_on_device
24
+
25
+ init_ipex()
26
+
27
+ from accelerate.utils import set_seed
28
+ from .library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux
29
+ from .library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
30
+
31
+ from .library import train_util as train_util
32
+
33
+ from .library.utils import setup_logging, add_logging_arguments
34
+
35
+ setup_logging()
36
+ import logging
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ from .library import config_util as config_util
41
+
42
+ from .library.config_util import (
43
+ ConfigSanitizer,
44
+ BlueprintGenerator,
45
+ )
46
+ from .library.custom_train_functions import apply_masked_loss, add_custom_train_arguments
47
+
48
+
49
+ class FluxTrainer:
50
+ def __init__(self):
51
+ self.sample_prompts_te_outputs = None
52
+
53
+ def sample_images(self, epoch, global_step, validation_settings):
54
+ image_tensors = flux_train_utils.sample_images(
55
+ self.accelerator, self.args, epoch, global_step, self.unet, self.vae, self.text_encoder, self.sample_prompts_te_outputs, validation_settings)
56
+ return image_tensors
57
+
58
+ def init_train(self, args):
59
+ train_util.verify_training_args(args)
60
+ train_util.prepare_dataset_args(args, True)
61
+ # sdxl_train_util.verify_sdxl_training_args(args)
62
+ deepspeed_utils.prepare_deepspeed_args(args)
63
+ setup_logging(args, reset=True)
64
+
65
+ # temporary: backward compatibility for deprecated options. remove in the future
66
+ if not args.skip_cache_check:
67
+ args.skip_cache_check = args.skip_latents_validity_check
68
+
69
+ if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
70
+ logger.warning(
71
+ "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
72
+ )
73
+ args.cache_text_encoder_outputs = True
74
+
75
+ if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
76
+ logger.warning(
77
+ "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります"
78
+ )
79
+ args.gradient_checkpointing = True
80
+
81
+ assert (
82
+ args.blocks_to_swap is None or args.blocks_to_swap == 0
83
+ ) or not args.cpu_offload_checkpointing, (
84
+ "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
85
+ )
86
+
87
+ cache_latents = args.cache_latents
88
+ use_dreambooth_method = args.in_json is None
89
+
90
+ if args.seed is not None:
91
+ set_seed(args.seed)
92
+
93
+ # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
94
+ if args.cache_latents:
95
+ latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(
96
+ args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
97
+ )
98
+ strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
99
+
100
+ # Prepare the dataset
101
+ if args.dataset_class is None:
102
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
103
+ if args.dataset_config is not None:
104
+ logger.info(f"Load dataset config from {args.dataset_config}")
105
+ user_config = config_util.load_user_config(args.dataset_config)
106
+ ignored = ["train_data_dir", "in_json"]
107
+ if any(getattr(args, attr) is not None for attr in ignored):
108
+ logger.warning(
109
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
110
+ ", ".join(ignored)
111
+ )
112
+ )
113
+ else:
114
+ if use_dreambooth_method:
115
+ logger.info("Using DreamBooth method.")
116
+ user_config = {
117
+ "datasets": [
118
+ {
119
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
120
+ args.train_data_dir, args.reg_data_dir
121
+ )
122
+ }
123
+ ]
124
+ }
125
+ else:
126
+ logger.info("Training with captions.")
127
+ user_config = {
128
+ "datasets": [
129
+ {
130
+ "subsets": [
131
+ {
132
+ "image_dir": args.train_data_dir,
133
+ "metadata_file": args.in_json,
134
+ }
135
+ ]
136
+ }
137
+ ]
138
+ }
139
+
140
+ blueprint = blueprint_generator.generate(user_config, args)
141
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
142
+ else:
143
+ train_dataset_group = train_util.load_arbitrary_dataset(args)
144
+
145
+ current_epoch = Value("i", 0)
146
+ current_step = Value("i", 0)
147
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
148
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
149
+
150
+ train_dataset_group.verify_bucket_reso_steps(16) # TODO confirm this value is appropriate
151
+
152
+ _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
153
+ if args.debug_dataset:
154
+ if args.cache_text_encoder_outputs:
155
+ strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
156
+ strategy_flux.FluxTextEncoderOutputsCachingStrategy(
157
+ args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
158
+ )
159
+ )
160
+ t5xxl_max_token_length = (
161
+ args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
162
+ )
163
+ strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))
164
+
165
+ train_dataset_group.set_current_strategies()
166
+ train_util.debug_dataset(train_dataset_group, True)
167
+ return
168
+ if len(train_dataset_group) == 0:
169
+ logger.error(
170
+ "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
171
+ )
172
+ return
173
+
174
+ if cache_latents:
175
+ assert (
176
+ train_dataset_group.is_latent_cacheable()
177
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
178
+
179
+ if args.cache_text_encoder_outputs:
180
+ assert (
181
+ train_dataset_group.is_text_encoder_output_cacheable()
182
+ ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
183
+
184
+ # prepare the accelerator
185
+ logger.info("prepare accelerator")
186
+ accelerator = train_util.prepare_accelerator(args)
187
+
188
+ # prepare dtypes for mixed precision and cast as appropriate
189
+ weight_dtype, save_dtype = train_util.prepare_dtype(args)
190
+
191
+ # load VAE for caching latents
192
+ ae = None
193
+ if cache_latents:
194
+ ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
195
+ ae.to(accelerator.device, dtype=weight_dtype)
196
+ ae.requires_grad_(False)
197
+ ae.eval()
198
+
199
+ train_dataset_group.new_cache_latents(ae, accelerator)
200
+
201
+ ae.to("cpu") # if no sampling, vae can be deleted
202
+ clean_memory_on_device(accelerator.device)
203
+
204
+ accelerator.wait_for_everyone()
205
+
206
+ # prepare tokenize strategy
207
+ if args.t5xxl_max_token_length is None:
208
+ if is_schnell:
209
+ t5xxl_max_token_length = 256
210
+ else:
211
+ t5xxl_max_token_length = 512
212
+ else:
213
+ t5xxl_max_token_length = args.t5xxl_max_token_length
214
+
215
+ flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)
216
+ strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy)
217
+
218
+ # load clip_l, t5xxl for caching text encoder outputs
219
+ clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
220
+ t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
221
+ clip_l.eval()
222
+ t5xxl.eval()
223
+ clip_l.requires_grad_(False)
224
+ t5xxl.requires_grad_(False)
225
+
226
+ text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask)
227
+ strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
228
+
229
+ # cache text encoder outputs
230
+ sample_prompts_te_outputs = None
231
+ if args.cache_text_encoder_outputs:
232
+ # Text Encodes are eval and no grad here
233
+ clip_l.to(accelerator.device)
234
+ t5xxl.to(accelerator.device)
235
+
236
+ text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
237
+ args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask
238
+ )
239
+ strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
240
+
241
+ with accelerator.autocast():
242
+ train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator)
243
+
244
+ # cache sample prompt's embeddings to free text encoder's memory
245
+ if args.sample_prompts is not None:
246
+ logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
247
+
248
+ text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
249
+
250
+ prompts = []
251
+ for line in args.sample_prompts:
252
+ line = line.strip()
253
+ if len(line) > 0 and line[0] != "#":
254
+ prompts.append(line)
255
+
256
+ # preprocess prompts
257
+ for i in range(len(prompts)):
258
+ prompt_dict = prompts[i]
259
+ if isinstance(prompt_dict, str):
260
+ from .library.train_util import line_to_prompt_dict
261
+
262
+ prompt_dict = line_to_prompt_dict(prompt_dict)
263
+ prompts[i] = prompt_dict
264
+ assert isinstance(prompt_dict, dict)
265
+
266
+ # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
267
+ prompt_dict["enum"] = i
268
+ prompt_dict.pop("subset", None)
269
+
270
+ sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
271
+ with accelerator.autocast(), torch.no_grad():
272
+ for prompt_dict in prompts:
273
+ for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
274
+ if p not in sample_prompts_te_outputs:
275
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
276
+ tokens_and_masks = flux_tokenize_strategy.tokenize(p)
277
+ sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
278
+ flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
279
+ )
280
+ self.sample_prompts_te_outputs = sample_prompts_te_outputs
281
+ accelerator.wait_for_everyone()
282
+
283
+ # now we can delete Text Encoders to free memory
284
+ clip_l = None
285
+ t5xxl = None
286
+ clean_memory_on_device(accelerator.device)
287
+
288
+ # load FLUX
289
+ _, flux = flux_utils.load_flow_model(
290
+ args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
291
+ )
292
+
293
+ if args.gradient_checkpointing:
294
+ flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
295
+
296
+ flux.requires_grad_(True)
297
+
298
+ # block swap
299
+
300
+ # backward compatibility
301
+ if args.blocks_to_swap is None:
302
+ blocks_to_swap = args.double_blocks_to_swap or 0
303
+ if args.single_blocks_to_swap is not None:
304
+ blocks_to_swap += args.single_blocks_to_swap // 2
305
+ if blocks_to_swap > 0:
306
+ logger.warning(
307
+ "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead."
308
+ " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。"
309
+ )
310
+ logger.info(
311
+ f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}."
312
+ )
313
+ args.blocks_to_swap = blocks_to_swap
314
+ del blocks_to_swap
315
+
316
+ self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
317
+ if self.is_swapping_blocks:
318
+ # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
319
+ # This idea is based on 2kpr's great work. Thank you!
320
+ logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
321
+ flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
322
+
323
+ if not cache_latents:
324
+ # load VAE here if not cached
325
+ ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu")
326
+ ae.requires_grad_(False)
327
+ ae.eval()
328
+ ae.to(accelerator.device, dtype=weight_dtype)
329
+
330
+ training_models = []
331
+ params_to_optimize = []
332
+ training_models.append(flux)
333
+ name_and_params = list(flux.named_parameters())
334
+ # single param group for now
335
+ params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate})
336
+ param_names = [[n for n, _ in name_and_params]]
337
+
338
+ # calculate number of trainable parameters
339
+ n_params = 0
340
+ for group in params_to_optimize:
341
+ for p in group["params"]:
342
+ n_params += p.numel()
343
+
344
+ accelerator.print(f"number of trainable parameters: {n_params}")
345
+
346
+ # prepare the classes needed for training
347
+ accelerator.print("prepare optimizer, data loader etc.")
348
+
349
+ if args.blockwise_fused_optimizers:
350
+ # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
351
+ # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
352
+ # This balances memory usage and management complexity.
353
+
354
+ # split params into groups. currently different learning rates are not supported
355
+ grouped_params = []
356
+ param_group = {}
357
+ for group in params_to_optimize:
358
+ named_parameters = list(flux.named_parameters())
359
+ assert len(named_parameters) == len(group["params"]), "number of parameters does not match"
360
+ for p, np in zip(group["params"], named_parameters):
361
+ # determine target layer and block index for each parameter
362
+ block_type = "other" # double, single or other
363
+ if np[0].startswith("double_blocks"):
364
+ block_index = int(np[0].split(".")[1])
365
+ block_type = "double"
366
+ elif np[0].startswith("single_blocks"):
367
+ block_index = int(np[0].split(".")[1])
368
+ block_type = "single"
369
+ else:
370
+ block_index = -1
371
+
372
+ param_group_key = (block_type, block_index)
373
+ if param_group_key not in param_group:
374
+ param_group[param_group_key] = []
375
+ param_group[param_group_key].append(p)
376
+
377
+ block_types_and_indices = []
378
+ for param_group_key, param_group in param_group.items():
379
+ block_types_and_indices.append(param_group_key)
380
+ grouped_params.append({"params": param_group, "lr": args.learning_rate})
381
+
382
+ num_params = 0
383
+ for p in param_group:
384
+ num_params += p.numel()
385
+ accelerator.print(f"block {param_group_key}: {num_params} parameters")
386
+
387
+ # prepare optimizers for each group
388
+ optimizers = []
389
+ for group in grouped_params:
390
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
391
+ optimizers.append(optimizer)
392
+ optimizer = optimizers[0] # avoid error in the following code
393
+
394
+ logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")
395
+
396
+ if train_util.is_schedulefree_optimizer(optimizers[0], args):
397
+ raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
398
+ self.optimizer_train_fn = lambda: None # dummy function
399
+ self.optimizer_eval_fn = lambda: None # dummy function
400
+ else:
401
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
402
+ self.optimizer_train_fn, self.optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
403
+
404
+ # prepare dataloader
405
+ # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
406
+ # some strategies can be None
407
+ train_dataset_group.set_current_strategies()
408
+
409
+ # number of DataLoader workers: note that persistent_workers cannot be used with 0 workers
410
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
411
+ train_dataloader = torch.utils.data.DataLoader(
412
+ train_dataset_group,
413
+ batch_size=1,
414
+ shuffle=True,
415
+ collate_fn=collator,
416
+ num_workers=n_workers,
417
+ persistent_workers=args.persistent_data_loader_workers,
418
+ )
419
+
420
+ # calculate the number of training steps
421
+ if args.max_train_epochs is not None:
422
+ args.max_train_steps = args.max_train_epochs * math.ceil(
423
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
424
+ )
425
+ accelerator.print(
426
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
427
+ )
428
+
429
+ # also pass the number of training steps to the dataset
430
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
431
+
432
+ # prepare the lr scheduler
433
+ if args.blockwise_fused_optimizers:
434
+ # prepare lr schedulers for each optimizer
435
+ lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
436
+ lr_scheduler = lr_schedulers[0] # avoid error in the following code
437
+ else:
438
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
439
+
440
+ # experimental feature: train fully in fp16/bf16 including gradients; cast the whole model to fp16/bf16
441
+ if args.full_fp16:
442
+ assert (
443
+ args.mixed_precision == "fp16"
444
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
445
+ accelerator.print("enable full fp16 training.")
446
+ flux.to(weight_dtype)
447
+ if clip_l is not None:
448
+ clip_l.to(weight_dtype)
449
+ t5xxl.to(weight_dtype) # TODO check works with fp16 or not
450
+ elif args.full_bf16:
451
+ assert (
452
+ args.mixed_precision == "bf16"
453
+ ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
454
+ accelerator.print("enable full bf16 training.")
455
+ flux.to(weight_dtype)
456
+ if clip_l is not None:
457
+ clip_l.to(weight_dtype)
458
+ t5xxl.to(weight_dtype)
459
+
460
+ # if we don't cache text encoder outputs, move them to device
461
+ if not args.cache_text_encoder_outputs:
462
+ clip_l.to(accelerator.device)
463
+ t5xxl.to(accelerator.device)
464
+
465
+ clean_memory_on_device(accelerator.device)
466
+
467
+ if args.deepspeed:
468
+ ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=flux)
469
+ # most ZeRO stages use optimizer partitioning, so we have to prepare the optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
470
+ ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
471
+ ds_model, optimizer, train_dataloader, lr_scheduler
472
+ )
473
+ training_models = [ds_model]
474
+
475
+ else:
476
+ # accelerator does some magic
477
+ # if we don't swap blocks, we can move the model to the device
478
+ flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
479
+ if self.is_swapping_blocks:
480
+ accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
481
+ optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
482
+
483
+ # experimental feature: train fully in fp16 including gradients; patch PyTorch to enable grad scaling in fp16
484
+ if args.full_fp16:
485
+ # During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine does.
486
+ # -> But we think it's ok to patch accelerator even if deepspeed is enabled.
487
+ train_util.patch_accelerator_for_fp16_training(accelerator)
488
+
489
+ # resume training if specified
490
+ train_util.resume_from_local_or_hf_if_specified(accelerator, args)
491
+
492
+ if args.fused_backward_pass:
493
+ # use fused optimizer for backward pass: other optimizers will be supported in the future
494
+ from .library import adafactor_fused
495
+
496
+ adafactor_fused.patch_adafactor_fused(optimizer)
497
+
498
+ for param_group, param_name_group in zip(optimizer.param_groups, param_names):
499
+ for parameter, param_name in zip(param_group["params"], param_name_group):
500
+ if parameter.requires_grad:
501
+
502
+ def create_grad_hook(p_name, p_group):
503
+ def grad_hook(tensor: torch.Tensor):
504
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
505
+ accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
506
+ optimizer.step_param(tensor, p_group)
507
+ tensor.grad = None
508
+
509
+ return grad_hook
510
+
511
+ parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
512
+
513
+ elif args.blockwise_fused_optimizers:
514
+ # prepare for additional optimizers and lr schedulers
515
+ for i in range(1, len(optimizers)):
516
+ optimizers[i] = accelerator.prepare(optimizers[i])
517
+ lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
518
+
519
+ # counters are used to determine when to step the optimizer
520
+ global optimizer_hooked_count
521
+ global num_parameters_per_group
522
+ global parameter_optimizer_map
523
+
524
+ optimizer_hooked_count = {}
525
+ num_parameters_per_group = [0] * len(optimizers)
526
+ parameter_optimizer_map = {}
527
+
528
+ for opt_idx, optimizer in enumerate(optimizers):
529
+ for param_group in optimizer.param_groups:
530
+ for parameter in param_group["params"]:
531
+ if parameter.requires_grad:
532
+
533
+ def grad_hook(parameter: torch.Tensor):
534
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
535
+ accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
536
+
537
+ i = parameter_optimizer_map[parameter]
538
+ optimizer_hooked_count[i] += 1
539
+ if optimizer_hooked_count[i] == num_parameters_per_group[i]:
540
+ optimizers[i].step()
541
+ optimizers[i].zero_grad(set_to_none=True)
542
+
543
+ parameter.register_post_accumulate_grad_hook(grad_hook)
544
+ parameter_optimizer_map[parameter] = opt_idx
545
+ num_parameters_per_group[opt_idx] += 1
546
+
547
+ # calculate the number of epochs
548
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
549
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
550
+ if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
551
+ args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
552
+
553
+ # start training
554
+ # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
555
+ accelerator.print("running training")
556
+ accelerator.print(f" num examples: {train_dataset_group.num_train_images}")
557
+ accelerator.print(f" num batches per epoch: {len(train_dataloader)}")
558
+ accelerator.print(f" num epochs: {num_train_epochs}")
559
+ accelerator.print(
560
+ f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
561
+ )
562
+ # accelerator.print(
563
+ # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
564
+ # )
565
+ accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}")
566
+ accelerator.print(f" total optimization steps: {args.max_train_steps}")
567
+
568
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
569
+ self.global_step = 0
570
+
571
+ noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
572
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
573
+
574
+ if accelerator.is_main_process:
575
+ init_kwargs = {}
576
+ if args.wandb_run_name:
577
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
578
+ if args.log_tracker_config is not None:
579
+ init_kwargs = toml.load(args.log_tracker_config)
580
+ accelerator.init_trackers(
581
+ "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
582
+ config=train_util.get_sanitized_config_or_none(args),
583
+ init_kwargs=init_kwargs,
584
+ )
585
+
586
+ if self.is_swapping_blocks:
587
+ accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
588
+
589
+ # For --sample_at_first
590
+ #flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
591
+
592
+ self.loss_recorder = train_util.LossRecorder()
593
+ epoch = 0 # avoid error when max_train_steps is 0
594
+
595
+ self.tokens_and_masks = tokens_and_masks
596
+ self.num_train_epochs = num_train_epochs
597
+ self.current_epoch = current_epoch
598
+ self.args = args
599
+ self.accelerator = accelerator
600
+ self.unet = flux
601
+ self.vae = ae
602
+ self.text_encoder = [clip_l, t5xxl]
603
+ self.save_dtype = save_dtype
604
+
605
+ def training_loop(break_at_steps, epoch):
606
+ global optimizer_hooked_count
607
+ steps_done = 0
608
+ #accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
609
+ progress_bar.set_description(f"Epoch {epoch + 1}/{num_train_epochs} - steps")
610
+ current_epoch.value = epoch + 1
611
+
612
+ for m in training_models:
613
+ m.train()
614
+
615
+ for step, batch in enumerate(train_dataloader):
616
+ current_step.value = self.global_step
617
+
618
+ if args.blockwise_fused_optimizers:
619
+ optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
620
+
621
+ with accelerator.accumulate(*training_models):
622
+ if "latents" in batch and batch["latents"] is not None:
623
+ latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
624
+ else:
625
+ with torch.no_grad():
626
+ # encode images to latents. images are [-1, 1]
627
+ latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype)
628
+
629
+ # if NaN is found in latents, show a warning and replace it with zeros
630
+ if torch.any(torch.isnan(latents)):
631
+ accelerator.print("NaN found in latents, replacing with zeros")
632
+ latents = torch.nan_to_num(latents, 0, out=latents)
633
+
634
+ text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
635
+ if text_encoder_outputs_list is not None:
636
+ text_encoder_conds = text_encoder_outputs_list
637
+ else:
638
+ # not cached or training, so get from text encoders
639
+ tokens_and_masks = batch["input_ids_list"]
640
+ with torch.no_grad():
641
+ input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
642
+ text_encoder_conds = text_encoding_strategy.encode_tokens(
643
+ flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
644
+ )
645
+ if args.full_fp16:
646
+ text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
647
+
648
+ # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
649
+
650
+ # Sample noise that we'll add to the latents
651
+ noise = torch.randn_like(latents)
652
+ bsz = latents.shape[0]
653
+
654
+ # get noisy model input and timesteps
655
+ noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
656
+ args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
657
+ )
658
+
659
+ # pack latents and get img_ids
660
+ packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
661
+ packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
662
+ img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
663
+
664
+ # get guidance: ensure args.guidance_scale is float
665
+ guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
666
+
667
+ # call model
668
+ l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
669
+ if not args.apply_t5_attn_mask:
670
+ t5_attn_mask = None
671
+
672
+ if args.bypass_flux_guidance:
673
+ flux_utils.bypass_flux_guidance(flux)
674
+
675
+ with accelerator.autocast():
676
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
677
+ model_pred = flux(
678
+ img=packed_noisy_model_input,
679
+ img_ids=img_ids,
680
+ txt=t5_out,
681
+ txt_ids=txt_ids,
682
+ y=l_pooled,
683
+ timesteps=timesteps / 1000,
684
+ guidance=guidance_vec,
685
+ txt_attention_mask=t5_attn_mask,
686
+ )
687
+
688
+ # unpack latents
689
+ model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
690
+
691
+ if args.bypass_flux_guidance:
692
+ flux_utils.restore_flux_guidance(flux)
693
+
694
+ # apply model prediction type
695
+ model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
696
+
697
+ # flow matching loss: this is different from SD3
698
+ target = noise - latents
699
+
700
+ # calculate loss
701
+ huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
702
+ loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
703
+ if weighting is not None:
704
+ loss = loss * weighting
705
+ if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
706
+ loss = apply_masked_loss(loss, batch)
707
+ loss = loss.mean([1, 2, 3])
708
+
709
+ loss_weights = batch["loss_weights"] # per-sample loss weights
710
+ loss = loss * loss_weights
711
+ loss = loss.mean()
712
+
713
+ # backward
714
+ accelerator.backward(loss)
715
+
716
+ if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
717
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
718
+ params_to_clip = []
719
+ for m in training_models:
720
+ params_to_clip.extend(m.parameters())
721
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
722
+
723
+ optimizer.step()
724
+ lr_scheduler.step()
725
+ optimizer.zero_grad(set_to_none=True)
726
+ else:
727
+ # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
728
+ lr_scheduler.step()
729
+ if args.blockwise_fused_optimizers:
730
+ for i in range(1, len(optimizers)):
731
+ lr_schedulers[i].step()
732
+
733
+ # Checks if the accelerator has performed an optimization step behind the scenes
734
+ if accelerator.sync_gradients:
735
+ progress_bar.update(1)
736
+ self.global_step += 1
737
+
738
+
739
+ current_loss = loss.detach().item() # this is the mean, so batch size should not matter
740
+ if len(accelerator.trackers) > 0:
741
+ logs = {"loss": current_loss}
742
+ train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
743
+
744
+ accelerator.log(logs, step=self.global_step)
745
+
746
+ self.loss_recorder.add(epoch=epoch, step=step, loss=current_loss, global_step=self.global_step)
747
+ avr_loss: float = self.loss_recorder.moving_average
748
+ logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
749
+ progress_bar.set_postfix(**logs)
750
+
751
+ if self.global_step >= break_at_steps:
752
+ break
753
+ steps_done += 1
754
+
755
+ if len(accelerator.trackers) > 0:
756
+ logs = {"loss/epoch": self.loss_recorder.moving_average}
757
+ accelerator.log(logs, step=epoch + 1)
758
+ return steps_done
759
+
760
+ return training_loop
761
+
762
+ def setup_parser() -> argparse.ArgumentParser:
763
+ parser = argparse.ArgumentParser()
764
+
765
+ add_logging_arguments(parser)
766
+ train_util.add_sd_models_arguments(parser) # TODO split this
767
+ train_util.add_dataset_arguments(parser, True, True, True)
768
+ train_util.add_training_arguments(parser, False)
769
+ train_util.add_masked_loss_arguments(parser)
770
+ deepspeed_utils.add_deepspeed_arguments(parser)
771
+ train_util.add_sd_saving_arguments(parser)
772
+ train_util.add_optimizer_arguments(parser)
773
+ config_util.add_config_arguments(parser)
774
+ add_custom_train_arguments(parser) # TODO remove this from here
775
+ train_util.add_dit_training_arguments(parser)
776
+ flux_train_utils.add_flux_train_arguments(parser)
777
+
778
+ parser.add_argument(
779
+ "--mem_eff_save",
780
+ action="store_true",
781
+ help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う",
782
+ )
783
+
784
+ parser.add_argument(
785
+ "--fused_optimizer_groups",
786
+ type=int,
787
+ default=None,
788
+ help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます",
789
+ )
790
+ parser.add_argument(
791
+ "--blockwise_fused_optimizers",
792
+ action="store_true",
793
+ help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
794
+ )
795
+ parser.add_argument(
796
+ "--skip_latents_validity_check",
797
+ action="store_true",
798
+ help="skip latents validity check / latentsの正当性チェックをスキップする",
799
+ )
800
+
801
+ parser.add_argument(
802
+ "--cpu_offload_checkpointing",
803
+ action="store_true",
804
+ help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
805
+ )
806
+ return parser
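Note on the fused backward pass above: registering a post-accumulate-grad hook per parameter lets the optimizer step and free each gradient as soon as it is ready, instead of keeping every gradient alive until a single global optimizer.step(). A minimal sketch of the same pattern outside this trainer (the toy model, the per-parameter SGD optimizers, and the clipping value are illustrative assumptions; the script above uses Adafactor's patched step_param instead):

import torch

# minimal sketch (PyTorch >= 2.1): give each parameter its own optimizer and step it
# from a post-accumulate-grad hook so the gradient can be released immediately
model = torch.nn.Linear(8, 8)
optimizers = {p: torch.optim.SGD([p], lr=1e-2) for p in model.parameters()}

def step_and_release(p: torch.Tensor):
    torch.nn.utils.clip_grad_norm_(p, max_norm=1.0)  # optional, mirrors max_grad_norm above
    optimizers[p].step()
    p.grad = None  # free the gradient right away to lower peak memory

for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(step_and_release)

model(torch.randn(4, 8)).sum().backward()  # hooks fire as each gradient is accumulated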
flux_train_network_comfy.py ADDED
@@ -0,0 +1,500 @@
1
+ import torch
2
+ import copy
3
+ import math
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+ import argparse
6
+ from .library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
7
+ from .train_network import NetworkTrainer, clean_memory_on_device, setup_parser as train_network_setup_parser
8
+
9
+ from accelerate import Accelerator
10
+
11
+
12
+ import logging
13
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
14
+ logger = logging.getLogger(__name__)
15
+
16
+ class FluxNetworkTrainer(NetworkTrainer):
17
+ def __init__(self):
18
+ super().__init__()
19
+ self.sample_prompts_te_outputs = None
20
+ self.is_schnell: Optional[bool] = None
21
+ self.is_swapping_blocks: bool = False
22
+
23
+ def assert_extra_args(self, args, train_dataset_group):
24
+ super().assert_extra_args(args, train_dataset_group)
25
+ # sdxl_train_util.verify_sdxl_training_args(args)
26
+
27
+ if args.fp8_base_unet:
28
+ args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
29
+
30
+ if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
31
+ logger.warning(
32
+ "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
33
+ )
34
+ args.cache_text_encoder_outputs = True
35
+
36
+ if args.cache_text_encoder_outputs:
37
+ assert (
38
+ train_dataset_group.is_text_encoder_output_cacheable()
39
+ ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
40
+
41
+ # prepare CLIP-L/T5XXL training flags
42
+ self.train_clip_l = not args.network_train_unet_only
43
+ self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
44
+
45
+ if args.max_token_length is not None:
46
+ logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
47
+
48
+ assert (
49
+ args.blocks_to_swap is None or args.blocks_to_swap == 0
50
+ ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
51
+
52
+ train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
53
+
54
+ def load_target_model(self, args, weight_dtype, accelerator):
55
+ # currently offload to cpu for some models
56
+
57
+ # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
58
+ loading_dtype = None if args.fp8_base else weight_dtype
59
+
60
+ # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
61
+ self.is_schnell, model = flux_utils.load_flow_model(
62
+ args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
63
+ )
64
+ if args.fp8_base:
65
+ # check dtype of model
66
+ if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2fnuz:
67
+ raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
68
+ elif model.dtype == torch.float8_e4m3fn or model.dtype == torch.float8_e5m2:
69
+ logger.info(f"Loaded {model.dtype} FLUX model")
70
+
71
+ self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
72
+ if self.is_swapping_blocks:
73
+ # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
74
+ logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
75
+ model.enable_block_swap(args.blocks_to_swap, accelerator.device)
76
+
77
+ clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
78
+ clip_l.eval()
79
+
80
+ # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
81
+ if args.fp8_base and not args.fp8_base_unet:
82
+ loading_dtype = None # as is
83
+ else:
84
+ loading_dtype = weight_dtype
85
+
86
+ # loading t5xxl to cpu takes a long time, so we should load to gpu in future
87
+ t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
88
+ t5xxl.eval()
89
+ if args.fp8_base and not args.fp8_base_unet:
90
+ # check dtype of model
91
+ if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
92
+ raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
93
+ elif t5xxl.dtype == torch.float8_e4m3fn:
94
+ logger.info("Loaded fp8 T5XXL model")
95
+
96
+ ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
97
+
98
+ return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
99
+
100
+ def get_tokenize_strategy(self, args):
101
+ _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
102
+
103
+ if args.t5xxl_max_token_length is None:
104
+ if is_schnell:
105
+ t5xxl_max_token_length = 256
106
+ else:
107
+ t5xxl_max_token_length = 512
108
+ else:
109
+ t5xxl_max_token_length = args.t5xxl_max_token_length
110
+
111
+ logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
112
+ return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
113
+
114
+ def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
115
+ return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
116
+
117
+ def get_latents_caching_strategy(self, args):
118
+ latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
119
+ return latents_caching_strategy
120
+
121
+ def get_text_encoding_strategy(self, args):
122
+ return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
123
+
124
+ def post_process_network(self, args, accelerator, network, text_encoders, unet):
125
+ # check t5xxl is trained or not
126
+ self.train_t5xxl = network.train_t5xxl
127
+
128
+ if self.train_t5xxl and args.cache_text_encoder_outputs:
129
+ raise ValueError(
130
+ "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
131
+ )
132
+
133
+ def get_models_for_text_encoding(self, args, accelerator, text_encoders):
134
+ if args.cache_text_encoder_outputs:
135
+ if self.train_clip_l and not self.train_t5xxl:
136
+ return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
137
+ else:
138
+ return None # no text encoders are needed for encoding because both are cached
139
+ else:
140
+ return text_encoders # both CLIP-L and T5XXL are needed for encoding
141
+
142
+ def get_text_encoders_train_flags(self, args, text_encoders):
143
+ return [self.train_clip_l, self.train_t5xxl]
144
+
145
+ def get_text_encoder_outputs_caching_strategy(self, args):
146
+ if args.cache_text_encoder_outputs:
147
+ # if either text encoder is trained, tokenization is still needed, so is_partial is True
148
+ return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
149
+ args.cache_text_encoder_outputs_to_disk,
150
+ args.text_encoder_batch_size,
151
+ args.skip_cache_check,
152
+ is_partial=self.train_clip_l or self.train_t5xxl,
153
+ apply_t5_attn_mask=args.apply_t5_attn_mask,
154
+ )
155
+ else:
156
+ return None
157
+
158
+ def cache_text_encoder_outputs_if_needed(
159
+ self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
160
+ ):
161
+ if args.cache_text_encoder_outputs:
162
+ if not args.lowram:
163
+ # reduce memory consumption
164
+ logger.info("move vae and unet to cpu to save memory")
165
+ org_vae_device = vae.device
166
+ org_unet_device = unet.device
167
+ vae.to("cpu")
168
+ unet.to("cpu")
169
+ clean_memory_on_device(accelerator.device)
170
+
171
+ # When the TE is not trained, it is not prepared by accelerate, so we need to use explicit autocast
172
+ logger.info("move text encoders to gpu")
173
+ text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
174
+ text_encoders[1].to(accelerator.device)
175
+
176
+ if text_encoders[1].dtype == torch.float8_e4m3fn:
177
+ # if we load fp8 weights, the model is already fp8, so we use it as is
178
+ self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
179
+ else:
180
+ # otherwise, we need to convert it to target dtype
181
+ text_encoders[1].to(weight_dtype)
182
+
183
+ with accelerator.autocast():
184
+ dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
185
+
186
+ # cache sample prompts
187
+ if args.sample_prompts is not None:
188
+ logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
189
+
190
+ tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
191
+ text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
192
+
193
+ prompts = []
194
+ for line in args.sample_prompts:
195
+ line = line.strip()
196
+ if len(line) > 0 and line[0] != "#":
197
+ prompts.append(line)
198
+
199
+ # preprocess prompts
200
+ for i in range(len(prompts)):
201
+ prompt_dict = prompts[i]
202
+ if isinstance(prompt_dict, str):
203
+ from .library.train_util import line_to_prompt_dict
204
+
205
+ prompt_dict = line_to_prompt_dict(prompt_dict)
206
+ prompts[i] = prompt_dict
207
+ assert isinstance(prompt_dict, dict)
208
+
209
+ # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
210
+ prompt_dict["enum"] = i
211
+ prompt_dict.pop("subset", None)
212
+
213
+ sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
214
+ with accelerator.autocast(), torch.no_grad():
215
+ for prompt_dict in prompts:
216
+ for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
217
+ if p not in sample_prompts_te_outputs:
218
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
219
+ tokens_and_masks = tokenize_strategy.tokenize(p)
220
+ sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
221
+ tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
222
+ )
223
+ self.sample_prompts_te_outputs = sample_prompts_te_outputs
224
+
225
+ accelerator.wait_for_everyone()
226
+
227
+ # move back to cpu
228
+ if not self.is_train_text_encoder(args):
229
+ logger.info("move CLIP-L back to cpu")
230
+ text_encoders[0].to("cpu")
231
+ logger.info("move t5XXL back to cpu")
232
+ text_encoders[1].to("cpu")
233
+ clean_memory_on_device(accelerator.device)
234
+
235
+ if not args.lowram:
236
+ logger.info("move vae and unet back to original device")
237
+ vae.to(org_vae_device)
238
+ unet.to(org_unet_device)
239
+ else:
240
+ # Text Encoder
241
+ text_encoders[0].to(accelerator.device, dtype=weight_dtype)
242
+ text_encoders[1].to(accelerator.device)
243
+
244
+ def sample_images(self, epoch, global_step, validation_settings):
245
+ text_encoders = self.get_models_for_text_encoding(self.args, self.accelerator, self.text_encoder)
246
+
247
+ image_tensors = flux_train_utils.sample_images(
248
+ self.accelerator, self.args, epoch, global_step, self.unet, self.vae, text_encoders, self.sample_prompts_te_outputs, validation_settings)
249
+ clean_memory_on_device(self.accelerator.device)
250
+ return image_tensors
251
+
252
+ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
253
+ noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
254
+ self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
255
+ return noise_scheduler
256
+
257
+ def encode_images_to_latents(self, args, accelerator, vae, images):
258
+ return vae.encode(images)
259
+
260
+ def shift_scale_latents(self, args, latents):
261
+ return latents
262
+
263
+ def get_noise_pred_and_target(
264
+ self,
265
+ args,
266
+ accelerator,
267
+ noise_scheduler,
268
+ latents,
269
+ batch,
270
+ text_encoder_conds,
271
+ unet: flux_models.Flux,
272
+ network,
273
+ weight_dtype,
274
+ train_unet,
275
+ ):
276
+ # Sample noise that we'll add to the latents
277
+ noise = torch.randn_like(latents)
278
+ bsz = latents.shape[0]
279
+
280
+ # get noisy model input and timesteps
281
+ noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
282
+ args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
283
+ )
284
+
285
+ # pack latents and get img_ids
286
+ packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
287
+ packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
288
+ img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
289
+
290
+ # get guidance
291
+ # ensure guidance_scale in args is float
292
+ guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
293
+
294
+ # ensure the hidden state will require grad
295
+ if args.gradient_checkpointing:
296
+ noisy_model_input.requires_grad_(True)
297
+ for t in text_encoder_conds:
298
+ if t is not None and t.dtype.is_floating_point:
299
+ t.requires_grad_(True)
300
+ img_ids.requires_grad_(True)
301
+ guidance_vec.requires_grad_(True)
302
+
303
+ # Predict the noise residual
304
+ l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
305
+ if not args.apply_t5_attn_mask:
306
+ t5_attn_mask = None
307
+
308
+ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
309
+ # normal forward
310
+ with accelerator.autocast():
311
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
312
+ model_pred = unet(
313
+ img=img,
314
+ img_ids=img_ids,
315
+ txt=t5_out,
316
+ txt_ids=txt_ids,
317
+ y=l_pooled,
318
+ timesteps=timesteps / 1000,
319
+ guidance=guidance_vec,
320
+ txt_attention_mask=t5_attn_mask,
321
+ )
322
+ """
323
+ else:
324
+ # split forward to reduce memory usage
325
+ assert network.train_blocks == "single", "train_blocks must be single for split mode"
326
+ with accelerator.autocast():
327
+ # move flux lower to cpu, and then move flux upper to gpu
328
+ unet.to("cpu")
329
+ clean_memory_on_device(accelerator.device)
330
+ self.flux_upper.to(accelerator.device)
331
+
332
+ # upper model does not require grad
333
+ with torch.no_grad():
334
+ intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
335
+ img=packed_noisy_model_input,
336
+ img_ids=img_ids,
337
+ txt=t5_out,
338
+ txt_ids=txt_ids,
339
+ y=l_pooled,
340
+ timesteps=timesteps / 1000,
341
+ guidance=guidance_vec,
342
+ txt_attention_mask=t5_attn_mask,
343
+ )
344
+
345
+ # move flux upper back to cpu, and then move flux lower to gpu
346
+ self.flux_upper.to("cpu")
347
+ clean_memory_on_device(accelerator.device)
348
+ unet.to(accelerator.device)
349
+
350
+ # lower model requires grad
351
+ intermediate_img.requires_grad_(True)
352
+ intermediate_txt.requires_grad_(True)
353
+ vec.requires_grad_(True)
354
+ pe.requires_grad_(True)
355
+ model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
356
+ """
357
+
358
+ return model_pred
359
+
360
+ if args.bypass_flux_guidance:
361
+ flux_utils.bypass_flux_guidance(unet)
362
+
363
+ model_pred = call_dit(
364
+ img=packed_noisy_model_input,
365
+ img_ids=img_ids,
366
+ t5_out=t5_out,
367
+ txt_ids=txt_ids,
368
+ l_pooled=l_pooled,
369
+ timesteps=timesteps,
370
+ guidance_vec=guidance_vec,
371
+ t5_attn_mask=t5_attn_mask,
372
+ )
373
+
374
+ # unpack latents
375
+ model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
376
+
377
+ if args.bypass_flux_guidance:  # for Flex
378
+ flux_utils.restore_flux_guidance(unet)
379
+
380
+ # apply model prediction type
381
+ model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
382
+
383
+ # flow matching loss: this is different from SD3
384
+ target = noise - latents
385
+
386
+ # differential output preservation
387
+ if "custom_attributes" in batch:
388
+ diff_output_pr_indices = []
389
+ for i, custom_attributes in enumerate(batch["custom_attributes"]):
390
+ if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
391
+ diff_output_pr_indices.append(i)
392
+
393
+ if len(diff_output_pr_indices) > 0:
394
+ network.set_multiplier(0.0)
395
+ unet.prepare_block_swap_before_forward()
396
+ with torch.no_grad():
397
+ model_pred_prior = call_dit(
398
+ img=packed_noisy_model_input[diff_output_pr_indices],
399
+ img_ids=img_ids[diff_output_pr_indices],
400
+ t5_out=t5_out[diff_output_pr_indices],
401
+ txt_ids=txt_ids[diff_output_pr_indices],
402
+ l_pooled=l_pooled[diff_output_pr_indices],
403
+ timesteps=timesteps[diff_output_pr_indices],
404
+ guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
405
+ t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
406
+ )
407
+ network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
408
+
409
+ model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
410
+ model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
411
+ args,
412
+ model_pred_prior,
413
+ noisy_model_input[diff_output_pr_indices],
414
+ sigmas[diff_output_pr_indices] if sigmas is not None else None,
415
+ )
416
+ target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
417
+
418
+ return model_pred, target, timesteps, weighting
419
+
420
+ def post_process_loss(self, loss, args, timesteps, noise_scheduler):
421
+ return loss
422
+
423
+ def get_sai_model_spec(self, args):
424
+ return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
425
+
426
+ def update_metadata(self, metadata, args):
427
+ metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
428
+ metadata["ss_weighting_scheme"] = args.weighting_scheme
429
+ metadata["ss_logit_mean"] = args.logit_mean
430
+ metadata["ss_logit_std"] = args.logit_std
431
+ metadata["ss_mode_scale"] = args.mode_scale
432
+ metadata["ss_guidance_scale"] = args.guidance_scale
433
+ metadata["ss_timestep_sampling"] = args.timestep_sampling
434
+ metadata["ss_sigmoid_scale"] = args.sigmoid_scale
435
+ metadata["ss_model_prediction_type"] = args.model_prediction_type
436
+ metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
437
+
438
+ def is_text_encoder_not_needed_for_training(self, args):
439
+ return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
440
+
441
+ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
442
+ if index == 0: # CLIP-L
443
+ return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
444
+ else: # T5XXL
445
+ text_encoder.encoder.embed_tokens.requires_grad_(True)
446
+
447
+ def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
448
+ if index == 0: # CLIP-L
449
+ logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
450
+ text_encoder.to(te_weight_dtype) # fp8
451
+ text_encoder.text_model.embeddings.to(dtype=weight_dtype)
452
+ else: # T5XXL
453
+
454
+ def prepare_fp8(text_encoder, target_dtype):
455
+ def forward_hook(module):
456
+ def forward(hidden_states):
457
+ hidden_gelu = module.act(module.wi_0(hidden_states))
458
+ hidden_linear = module.wi_1(hidden_states)
459
+ hidden_states = hidden_gelu * hidden_linear
460
+ hidden_states = module.dropout(hidden_states)
461
+
462
+ hidden_states = module.wo(hidden_states)
463
+ return hidden_states
464
+
465
+ return forward
466
+
467
+ for module in text_encoder.modules():
468
+ if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
469
+ # print("set", module.__class__.__name__, "to", target_dtype)
470
+ module.to(target_dtype)
471
+ if module.__class__.__name__ in ["T5DenseGatedActDense"]:
472
+ # print("set", module.__class__.__name__, "hooks")
473
+ module.forward = forward_hook(module)
474
+
475
+ if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
476
+ logger.info(f"T5XXL already prepared for fp8")
477
+ else:
478
+ logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
479
+ text_encoder.to(te_weight_dtype) # fp8
480
+ prepare_fp8(text_encoder, weight_dtype)
481
+
482
+ def prepare_unet_with_accelerator(
483
+ self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
484
+ ) -> torch.nn.Module:
485
+ if not self.is_swapping_blocks:
486
+ return super().prepare_unet_with_accelerator(args, accelerator, unet)
487
+
488
+ # block swapping is enabled: move the model to the device except for the swapped blocks
489
+ flux: flux_models.Flux = unet
490
+ flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
491
+ accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
492
+ accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
493
+
494
+ return flux
495
+
496
+
497
+ def setup_parser() -> argparse.ArgumentParser:
498
+ parser = train_network_setup_parser() # base parser from train_network; avoid recursing into this function
499
+ train_util.add_dit_training_arguments(parser)
500
+ flux_train_utils.add_flux_train_arguments(parser)
+ return parser
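For reference, the flow-matching objective used in get_noise_pred_and_target interpolates the clean latents toward noise and regresses the model output against the velocity noise - latents. A self-contained sketch of that target construction (shapes, the uniform sigma sampling, and the stand-in prediction are illustrative assumptions, not the exact sampling logic of flux_train_utils):

import torch

# illustrative shapes only: (batch, channels, height, width) latents
latents = torch.randn(2, 16, 64, 64)
noise = torch.randn_like(latents)
sigmas = torch.rand(latents.shape[0]).view(-1, 1, 1, 1)   # one noise level per sample

# rectified-flow interpolation between data and noise
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise

# the network predicts the velocity; the training target is noise - latents
model_pred = torch.randn_like(latents)                     # stand-in for the FLUX transformer output
target = noise - latents
loss = torch.nn.functional.mse_loss(model_pred.float(), target.float())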
hf_token.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "hf_token": "your_token_here"
3
+ }
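This placeholder file is presumably read by the Gradio app to authenticate Hugging Face uploads; a minimal way to load the token (the consuming code in app.py is not shown here):

import json

with open("hf_token.json") as f:
    hf_token = json.load(f)["hf_token"]  # replace "your_token_here" with a real token first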
icon.png ADDED
install.js ADDED
@@ -0,0 +1,96 @@
1
+ module.exports = {
2
+ run: [
3
+ {
4
+ method: "shell.run",
5
+ params: {
6
+ venv: "env",
7
+ message: [
8
+ "git config --global --add safe.directory '*'",
9
+ "git clone -b sd3 https://github.com/kohya-ss/sd-scripts"
10
+ ]
11
+ }
12
+ },
13
+ {
14
+ method: "shell.run",
15
+ params: {
16
+ path: "sd-scripts",
17
+ venv: "../env",
18
+ message: [
19
+ "uv pip install -r requirements.txt",
20
+ ]
21
+ }
22
+ },
23
+ {
24
+ method: "shell.run",
25
+ params: {
26
+ venv: "env",
27
+ message: [
28
+ "pip uninstall -y diffusers[torch] torch torchaudio torchvision",
29
+ "uv pip install -r requirements.txt",
30
+ ]
31
+ }
32
+ },
33
+ {
34
+ method: "script.start",
35
+ params: {
36
+ uri: "torch.js",
37
+ params: {
38
+ venv: "env",
39
+ // xformers: true // uncomment this line if your project requires xformers
40
+ }
41
+ }
42
+ },
43
+ {
44
+ method: "fs.link",
45
+ params: {
46
+ drive: {
47
+ vae: "models/vae",
48
+ clip: "models/clip",
49
+ unet: "models/unet",
50
+ loras: "outputs",
51
+ },
52
+ peers: [
53
+ "https://github.com/pinokiofactory/stable-diffusion-webui-forge.git",
54
+ "https://github.com/pinokiofactory/comfy.git",
55
+ "https://github.com/cocktailpeanutlabs/comfyui.git",
56
+ "https://github.com/cocktailpeanutlabs/fooocus.git",
57
+ "https://github.com/cocktailpeanutlabs/automatic1111.git",
58
+ ]
59
+ }
60
+ },
61
+ // {
62
+ // method: "fs.download",
63
+ // params: {
64
+ // uri: [
65
+ // "https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/clip_l.safetensors?download=true",
66
+ // "https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/t5xxl_fp16.safetensors?download=true",
67
+ // ],
68
+ // dir: "models/clip"
69
+ // }
70
+ // },
71
+ // {
72
+ // method: "fs.download",
73
+ // params: {
74
+ // uri: [
75
+ // "https://huggingface.co/cocktailpeanut/xulf-dev/resolve/main/ae.sft?download=true",
76
+ // ],
77
+ // dir: "models/vae"
78
+ // }
79
+ // },
80
+ // {
81
+ // method: "fs.download",
82
+ // params: {
83
+ // uri: [
84
+ // "https://huggingface.co/cocktailpeanut/xulf-dev/resolve/main/flux1-dev.sft?download=true",
85
+ // ],
86
+ // dir: "models/unet"
87
+ // }
88
+ // },
89
+ {
90
+ method: "fs.link",
91
+ params: {
92
+ venv: "env"
93
+ }
94
+ }
95
+ ]
96
+ }
library/__init__.py ADDED
File without changes
library/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (163 Bytes). View file
 
library/__pycache__/config_util.cpython-310.pyc ADDED
Binary file (20.2 kB). View file
 
library/__pycache__/custom_offloading_utils.cpython-310.pyc ADDED
Binary file (6.98 kB). View file
 
library/__pycache__/custom_train_functions.cpython-310.pyc ADDED
Binary file (13.5 kB). View file
 
library/__pycache__/deepspeed_utils.cpython-310.pyc ADDED
Binary file (4.79 kB). View file
 
library/__pycache__/device_utils.cpython-310.pyc ADDED
Binary file (2.07 kB). View file
 
library/__pycache__/flux_models.cpython-310.pyc ADDED
Binary file (30.6 kB). View file
 
library/__pycache__/flux_train_utils.cpython-310.pyc ADDED
Binary file (14.8 kB). View file
 
library/__pycache__/flux_utils.cpython-310.pyc ADDED
Binary file (16.5 kB). View file
 
library/__pycache__/huggingface_util.cpython-310.pyc ADDED
Binary file (2.79 kB). View file
 
library/__pycache__/model_util.cpython-310.pyc ADDED
Binary file (32.8 kB). View file
 
library/__pycache__/original_unet.cpython-310.pyc ADDED
Binary file (44.1 kB). View file
 
library/__pycache__/sai_model_spec.cpython-310.pyc ADDED
Binary file (5.68 kB). View file
 
library/__pycache__/sd3_models.cpython-310.pyc ADDED
Binary file (38.8 kB). View file
 
library/__pycache__/sd3_utils.cpython-310.pyc ADDED
Binary file (8.48 kB). View file
 
library/__pycache__/strategy_base.cpython-310.pyc ADDED
Binary file (17.4 kB). View file
 
library/__pycache__/strategy_sd.cpython-310.pyc ADDED
Binary file (6.7 kB). View file
 
library/__pycache__/train_util.cpython-310.pyc ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa71a44895d0a006e41ba9fadbd0177a9ad5499cc89aeb2266aa1c7a9597e82e
3
+ size 164434
library/__pycache__/utils.cpython-310.pyc ADDED
Binary file (15.9 kB). View file
 
library/adafactor_fused.py ADDED
@@ -0,0 +1,138 @@
1
+ import math
2
+ import torch
3
+ from transformers import Adafactor
4
+
5
+ # stochastic rounding for bfloat16
6
+ # The implementation was provided by 2kpr. Thank you very much!
7
+
8
+ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
9
+ """
10
+ copies source into target using stochastic rounding
11
+
12
+ Args:
13
+ target: the target tensor with dtype=bfloat16
14
+ source: the source tensor with dtype=float32
15
+ """
16
+ # create a random 16 bit integer
17
+ result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
18
+
19
+ # add the random number to the lower 16 bit of the mantissa
20
+ result.add_(source.view(dtype=torch.int32))
21
+
22
+ # mask off the lower 16 bit of the mantissa
23
+ result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
24
+
25
+ # copy the higher 16 bit into the target tensor
26
+ target.copy_(result.view(dtype=torch.float32))
27
+
28
+ del result
29
+
30
+
31
+ @torch.no_grad()
32
+ def adafactor_step_param(self, p, group):
33
+ if p.grad is None:
34
+ return
35
+ grad = p.grad
36
+ if grad.dtype in {torch.float16, torch.bfloat16}:
37
+ grad = grad.float()
38
+ if grad.is_sparse:
39
+ raise RuntimeError("Adafactor does not support sparse gradients.")
40
+
41
+ state = self.state[p]
42
+ grad_shape = grad.shape
43
+
44
+ factored, use_first_moment = Adafactor._get_options(group, grad_shape)
45
+ # State Initialization
46
+ if len(state) == 0:
47
+ state["step"] = 0
48
+
49
+ if use_first_moment:
50
+ # Exponential moving average of gradient values
51
+ state["exp_avg"] = torch.zeros_like(grad)
52
+ if factored:
53
+ state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
54
+ state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
55
+ else:
56
+ state["exp_avg_sq"] = torch.zeros_like(grad)
57
+
58
+ state["RMS"] = 0
59
+ else:
60
+ if use_first_moment:
61
+ state["exp_avg"] = state["exp_avg"].to(grad)
62
+ if factored:
63
+ state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
64
+ state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
65
+ else:
66
+ state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
67
+
68
+ p_data_fp32 = p
69
+ if p.dtype in {torch.float16, torch.bfloat16}:
70
+ p_data_fp32 = p_data_fp32.float()
71
+
72
+ state["step"] += 1
73
+ state["RMS"] = Adafactor._rms(p_data_fp32)
74
+ lr = Adafactor._get_lr(group, state)
75
+
76
+ beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
77
+ update = (grad**2) + group["eps"][0]
78
+ if factored:
79
+ exp_avg_sq_row = state["exp_avg_sq_row"]
80
+ exp_avg_sq_col = state["exp_avg_sq_col"]
81
+
82
+ exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
83
+ exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
84
+
85
+ # Approximation of exponential moving average of square of gradient
86
+ update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
87
+ update.mul_(grad)
88
+ else:
89
+ exp_avg_sq = state["exp_avg_sq"]
90
+
91
+ exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
92
+ update = exp_avg_sq.rsqrt().mul_(grad)
93
+
94
+ update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
95
+ update.mul_(lr)
96
+
97
+ if use_first_moment:
98
+ exp_avg = state["exp_avg"]
99
+ exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
100
+ update = exp_avg
101
+
102
+ if group["weight_decay"] != 0:
103
+ p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
104
+
105
+ p_data_fp32.add_(-update)
106
+
107
+ # if p.dtype in {torch.float16, torch.bfloat16}:
108
+ # p.copy_(p_data_fp32)
109
+
110
+ if p.dtype == torch.bfloat16:
111
+ copy_stochastic_(p, p_data_fp32)
112
+ elif p.dtype == torch.float16:
113
+ p.copy_(p_data_fp32)
114
+
115
+
116
+ @torch.no_grad()
117
+ def adafactor_step(self, closure=None):
118
+ """
119
+ Performs a single optimization step
120
+
121
+ Arguments:
122
+ closure (callable, optional): A closure that reevaluates the model
123
+ and returns the loss.
124
+ """
125
+ loss = None
126
+ if closure is not None:
127
+ loss = closure()
128
+
129
+ for group in self.param_groups:
130
+ for p in group["params"]:
131
+ adafactor_step_param(self, p, group)
132
+
133
+ return loss
134
+
135
+
136
+ def patch_adafactor_fused(optimizer: Adafactor):
137
+ optimizer.step_param = adafactor_step_param.__get__(optimizer)
138
+ optimizer.step = adafactor_step.__get__(optimizer)
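The copy_stochastic_ helper above performs stochastic rounding when writing fp32 Adafactor updates back into bf16 parameters: a random value is added below the bf16 mantissa before truncation, so rounding is unbiased in expectation. A quick demonstration of the effect (import path and sample values are illustrative assumptions):

import torch
from library.adafactor_fused import copy_stochastic_  # assumes the repo root is on sys.path

# 1.0001 is not representable in bf16; a plain cast rounds every element to 1.0,
# while stochastic rounding occasionally rounds up, so the mean stays near 1.0001
source = torch.full((10_000,), 1.0001, dtype=torch.float32)
target = torch.empty_like(source, dtype=torch.bfloat16)
copy_stochastic_(target, source)

print(source.to(torch.bfloat16).float().mean())  # exactly 1.0
print(target.float().mean())                     # close to 1.0001 on average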
library/attention_processors.py ADDED
@@ -0,0 +1,227 @@
1
+ import math
2
+ from typing import Any
3
+ from einops import rearrange
4
+ import torch
5
+ from diffusers.models.attention_processor import Attention
6
+
7
+
8
+ # flash attention forwards and backwards
9
+
10
+ # https://arxiv.org/abs/2205.14135
11
+
12
+ EPSILON = 1e-6
13
+
14
+
15
+ class FlashAttentionFunction(torch.autograd.function.Function):
16
+ @staticmethod
17
+ @torch.no_grad()
18
+ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
19
+ """Algorithm 2 in the paper"""
20
+
21
+ device = q.device
22
+ dtype = q.dtype
23
+ max_neg_value = -torch.finfo(q.dtype).max
24
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
25
+
26
+ o = torch.zeros_like(q)
27
+ all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
28
+ all_row_maxes = torch.full(
29
+ (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
30
+ )
31
+
32
+ scale = q.shape[-1] ** -0.5
33
+
34
+ if mask is None:
35
+ mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
36
+ else:
37
+ mask = rearrange(mask, "b n -> b 1 1 n")
38
+ mask = mask.split(q_bucket_size, dim=-1)
39
+
40
+ row_splits = zip(
41
+ q.split(q_bucket_size, dim=-2),
42
+ o.split(q_bucket_size, dim=-2),
43
+ mask,
44
+ all_row_sums.split(q_bucket_size, dim=-2),
45
+ all_row_maxes.split(q_bucket_size, dim=-2),
46
+ )
47
+
48
+ for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
49
+ q_start_index = ind * q_bucket_size - qk_len_diff
50
+
51
+ col_splits = zip(
52
+ k.split(k_bucket_size, dim=-2),
53
+ v.split(k_bucket_size, dim=-2),
54
+ )
55
+
56
+ for k_ind, (kc, vc) in enumerate(col_splits):
57
+ k_start_index = k_ind * k_bucket_size
58
+
59
+ attn_weights = (
60
+ torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
61
+ )
62
+
63
+ if row_mask is not None:
64
+ attn_weights.masked_fill_(~row_mask, max_neg_value)
65
+
66
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
67
+ causal_mask = torch.ones(
68
+ (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
69
+ ).triu(q_start_index - k_start_index + 1)
70
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
71
+
72
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
73
+ attn_weights -= block_row_maxes
74
+ exp_weights = torch.exp(attn_weights)
75
+
76
+ if row_mask is not None:
77
+ exp_weights.masked_fill_(~row_mask, 0.0)
78
+
79
+ block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
80
+ min=EPSILON
81
+ )
82
+
83
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
84
+
85
+ exp_values = torch.einsum(
86
+ "... i j, ... j d -> ... i d", exp_weights, vc
87
+ )
88
+
89
+ exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
90
+ exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
91
+
92
+ new_row_sums = (
93
+ exp_row_max_diff * row_sums
94
+ + exp_block_row_max_diff * block_row_sums
95
+ )
96
+
97
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
98
+ (exp_block_row_max_diff / new_row_sums) * exp_values
99
+ )
100
+
101
+ row_maxes.copy_(new_row_maxes)
102
+ row_sums.copy_(new_row_sums)
103
+
104
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
105
+ ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
106
+
107
+ return o
108
+
109
+ @staticmethod
110
+ @torch.no_grad()
111
+ def backward(ctx, do):
112
+ """Algorithm 4 in the paper"""
113
+
114
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
115
+ q, k, v, o, l, m = ctx.saved_tensors
116
+
117
+ device = q.device
118
+
119
+ max_neg_value = -torch.finfo(q.dtype).max
120
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
121
+
122
+ dq = torch.zeros_like(q)
123
+ dk = torch.zeros_like(k)
124
+ dv = torch.zeros_like(v)
125
+
126
+ row_splits = zip(
127
+ q.split(q_bucket_size, dim=-2),
128
+ o.split(q_bucket_size, dim=-2),
129
+ do.split(q_bucket_size, dim=-2),
130
+ mask,
131
+ l.split(q_bucket_size, dim=-2),
132
+ m.split(q_bucket_size, dim=-2),
133
+ dq.split(q_bucket_size, dim=-2),
134
+ )
135
+
136
+ for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
137
+ q_start_index = ind * q_bucket_size - qk_len_diff
138
+
139
+ col_splits = zip(
140
+ k.split(k_bucket_size, dim=-2),
141
+ v.split(k_bucket_size, dim=-2),
142
+ dk.split(k_bucket_size, dim=-2),
143
+ dv.split(k_bucket_size, dim=-2),
144
+ )
145
+
146
+ for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
147
+ k_start_index = k_ind * k_bucket_size
148
+
149
+ attn_weights = (
150
+ torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
151
+ )
152
+
153
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
154
+ causal_mask = torch.ones(
155
+ (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
156
+ ).triu(q_start_index - k_start_index + 1)
157
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
158
+
159
+ exp_attn_weights = torch.exp(attn_weights - mc)
160
+
161
+ if row_mask is not None:
162
+ exp_attn_weights.masked_fill_(~row_mask, 0.0)
163
+
164
+ p = exp_attn_weights / lc
165
+
166
+ dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
167
+ dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
168
+
169
+ D = (doc * oc).sum(dim=-1, keepdims=True)
170
+ ds = p * scale * (dp - D)
171
+
172
+ dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
173
+ dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
174
+
175
+ dqc.add_(dq_chunk)
176
+ dkc.add_(dk_chunk)
177
+ dvc.add_(dv_chunk)
178
+
179
+ return dq, dk, dv, None, None, None, None
180
+
181
+
182
+ class FlashAttnProcessor:
183
+ def __call__(
184
+ self,
185
+ attn: Attention,
186
+ hidden_states,
187
+ encoder_hidden_states=None,
188
+ attention_mask=None,
189
+ ) -> Any:
190
+ q_bucket_size = 512
191
+ k_bucket_size = 1024
192
+
193
+ h = attn.heads
194
+ q = attn.to_q(hidden_states)
195
+
196
+ encoder_hidden_states = (
197
+ encoder_hidden_states
198
+ if encoder_hidden_states is not None
199
+ else hidden_states
200
+ )
201
+ encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
202
+
203
+ if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
204
+ context_k, context_v = attn.hypernetwork.forward(
205
+ hidden_states, encoder_hidden_states
206
+ )
207
+ context_k = context_k.to(hidden_states.dtype)
208
+ context_v = context_v.to(hidden_states.dtype)
209
+ else:
210
+ context_k = encoder_hidden_states
211
+ context_v = encoder_hidden_states
212
+
213
+ k = attn.to_k(context_k)
214
+ v = attn.to_v(context_v)
215
+ del encoder_hidden_states, hidden_states
216
+
217
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
218
+
219
+ out = FlashAttentionFunction.apply(
220
+ q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
221
+ )
222
+
223
+ out = rearrange(out, "b h n d -> b n (h d)")
224
+
225
+ out = attn.to_out[0](out)
226
+ out = attn.to_out[1](out)
227
+ return out
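FlashAttentionFunction above computes attention in query/key buckets so the full attention matrix is never materialized. A small smoke test comparing it against PyTorch's scaled_dot_product_attention (shapes are arbitrary; with a 256-token sequence the bucket sizes used by FlashAttnProcessor cover it in one tile):

import torch
from library.attention_processors import FlashAttentionFunction  # assumes the repo root is on sys.path

# (batch, heads, sequence, head_dim); mask=None, causal=False,
# q_bucket_size=512 and k_bucket_size=1024 as used by FlashAttnProcessor
q, k, v = (torch.randn(1, 8, 256, 64) for _ in range(3))
out = FlashAttentionFunction.apply(q, k, v, None, False, 512, 1024)

ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(out.shape)                            # torch.Size([1, 8, 256, 64])
print(torch.allclose(out, ref, atol=1e-4))  # should print True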
library/config_util.py ADDED
@@ -0,0 +1,717 @@
1
+ import argparse
2
+ from dataclasses import (
3
+ asdict,
4
+ dataclass,
5
+ )
6
+ import functools
7
+ import random
8
+ from textwrap import dedent, indent
9
+ import json
10
+ from pathlib import Path
11
+
12
+ # from toolz import curry
13
+ from typing import Dict, List, Optional, Sequence, Tuple, Union
14
+
15
+ import toml
16
+ import voluptuous
17
+ from voluptuous import (
18
+ Any,
19
+ ExactSequence,
20
+ MultipleInvalid,
21
+ Object,
22
+ Required,
23
+ Schema,
24
+ )
25
+
26
+
27
+ from . import train_util
28
+ from .train_util import (
29
+ DreamBoothSubset,
30
+ FineTuningSubset,
31
+ ControlNetSubset,
32
+ DreamBoothDataset,
33
+ FineTuningDataset,
34
+ ControlNetDataset,
35
+ DatasetGroup,
36
+ )
37
+ from .utils import setup_logging
38
+
39
+ setup_logging()
40
+ import logging
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
+ def add_config_arguments(parser: argparse.ArgumentParser):
46
+ parser.add_argument(
47
+ "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
48
+ )
49
+
50
+
51
+ # TODO: inherit Params class in Subset, Dataset
52
+
53
+
54
+ @dataclass
55
+ class BaseSubsetParams:
56
+ image_dir: Optional[str] = None
57
+ num_repeats: int = 1
58
+ shuffle_caption: bool = False
59
+ caption_separator: str = (",",)
60
+ keep_tokens: int = 0
61
+ keep_tokens_separator: str = (None,)
62
+ secondary_separator: Optional[str] = None
63
+ enable_wildcard: bool = False
64
+ color_aug: bool = False
65
+ flip_aug: bool = False
66
+ face_crop_aug_range: Optional[Tuple[float, float]] = None
67
+ random_crop: bool = False
68
+ caption_prefix: Optional[str] = None
69
+ caption_suffix: Optional[str] = None
70
+ caption_dropout_rate: float = 0.0
71
+ caption_dropout_every_n_epochs: int = 0
72
+ caption_tag_dropout_rate: float = 0.0
73
+ token_warmup_min: int = 1
74
+ token_warmup_step: float = 0
75
+ custom_attributes: Optional[Dict[str, Any]] = None
76
+
77
+
78
+ @dataclass
79
+ class DreamBoothSubsetParams(BaseSubsetParams):
80
+ is_reg: bool = False
81
+ class_tokens: Optional[str] = None
82
+ caption_extension: str = ".caption"
83
+ cache_info: bool = False
84
+ alpha_mask: bool = False
85
+
86
+
87
+ @dataclass
88
+ class FineTuningSubsetParams(BaseSubsetParams):
89
+ metadata_file: Optional[str] = None
90
+ alpha_mask: bool = False
91
+
92
+
93
+ @dataclass
94
+ class ControlNetSubsetParams(BaseSubsetParams):
95
+ conditioning_data_dir: str = None
96
+ caption_extension: str = ".caption"
97
+ cache_info: bool = False
98
+
99
+
100
+ @dataclass
101
+ class BaseDatasetParams:
102
+ resolution: Optional[Tuple[int, int]] = None
103
+ network_multiplier: float = 1.0
104
+ debug_dataset: bool = False
105
+
106
+
107
+ @dataclass
108
+ class DreamBoothDatasetParams(BaseDatasetParams):
109
+ batch_size: int = 1
110
+ enable_bucket: bool = False
111
+ min_bucket_reso: int = 256
112
+ max_bucket_reso: int = 1024
113
+ bucket_reso_steps: int = 64
114
+ bucket_no_upscale: bool = False
115
+ prior_loss_weight: float = 1.0
116
+
117
+
118
+ @dataclass
119
+ class FineTuningDatasetParams(BaseDatasetParams):
120
+ batch_size: int = 1
121
+ enable_bucket: bool = False
122
+ min_bucket_reso: int = 256
123
+ max_bucket_reso: int = 1024
124
+ bucket_reso_steps: int = 64
125
+ bucket_no_upscale: bool = False
126
+
127
+
128
+ @dataclass
129
+ class ControlNetDatasetParams(BaseDatasetParams):
130
+ batch_size: int = 1
131
+ enable_bucket: bool = False
132
+ min_bucket_reso: int = 256
133
+ max_bucket_reso: int = 1024
134
+ bucket_reso_steps: int = 64
135
+ bucket_no_upscale: bool = False
136
+
137
+
138
+ @dataclass
139
+ class SubsetBlueprint:
140
+ params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
141
+
142
+
143
+ @dataclass
144
+ class DatasetBlueprint:
145
+ is_dreambooth: bool
146
+ is_controlnet: bool
147
+ params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
148
+ subsets: Sequence[SubsetBlueprint]
149
+
150
+
151
+ @dataclass
152
+ class DatasetGroupBlueprint:
153
+ datasets: Sequence[DatasetBlueprint]
154
+
155
+
156
+ @dataclass
157
+ class Blueprint:
158
+ dataset_group: DatasetGroupBlueprint
159
+
160
+
161
+ class ConfigSanitizer:
162
+ # @curry
163
+ @staticmethod
164
+ def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
165
+ Schema(ExactSequence([klass, klass]))(value)
166
+ return tuple(value)
167
+
168
+ # @curry
169
+ @staticmethod
170
+ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
171
+ Schema(Any(klass, ExactSequence([klass, klass])))(value)
172
+ try:
173
+ Schema(klass)(value)
174
+ return (value, value)
175
+ except:
176
+ return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
177
+
178
+ # subset schema
179
+ SUBSET_ASCENDABLE_SCHEMA = {
180
+ "color_aug": bool,
181
+ "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
182
+ "flip_aug": bool,
183
+ "num_repeats": int,
184
+ "random_crop": bool,
185
+ "shuffle_caption": bool,
186
+ "keep_tokens": int,
187
+ "keep_tokens_separator": str,
188
+ "secondary_separator": str,
189
+ "caption_separator": str,
190
+ "enable_wildcard": bool,
191
+ "token_warmup_min": int,
192
+ "token_warmup_step": Any(float, int),
193
+ "caption_prefix": str,
194
+ "caption_suffix": str,
195
+ "custom_attributes": dict,
196
+ }
197
+ # DO means DropOut
198
+ DO_SUBSET_ASCENDABLE_SCHEMA = {
199
+ "caption_dropout_every_n_epochs": int,
200
+ "caption_dropout_rate": Any(float, int),
201
+ "caption_tag_dropout_rate": Any(float, int),
202
+ }
203
+ # DB means DreamBooth
204
+ DB_SUBSET_ASCENDABLE_SCHEMA = {
205
+ "caption_extension": str,
206
+ "class_tokens": str,
207
+ "cache_info": bool,
208
+ }
209
+ DB_SUBSET_DISTINCT_SCHEMA = {
210
+ Required("image_dir"): str,
211
+ "is_reg": bool,
212
+ "alpha_mask": bool,
213
+ }
214
+ # FT means FineTuning
215
+ FT_SUBSET_DISTINCT_SCHEMA = {
216
+ Required("metadata_file"): str,
217
+ "image_dir": str,
218
+ "alpha_mask": bool,
219
+ }
220
+ CN_SUBSET_ASCENDABLE_SCHEMA = {
221
+ "caption_extension": str,
222
+ "cache_info": bool,
223
+ }
224
+ CN_SUBSET_DISTINCT_SCHEMA = {
225
+ Required("image_dir"): str,
226
+ Required("conditioning_data_dir"): str,
227
+ }
228
+
229
+ # datasets schema
230
+ DATASET_ASCENDABLE_SCHEMA = {
231
+ "batch_size": int,
232
+ "bucket_no_upscale": bool,
233
+ "bucket_reso_steps": int,
234
+ "enable_bucket": bool,
235
+ "max_bucket_reso": int,
236
+ "min_bucket_reso": int,
237
+ "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
238
+ "network_multiplier": float,
239
+ }
240
+
241
+ # options handled by argparse but not handled by user config
242
+ ARGPARSE_SPECIFIC_SCHEMA = {
243
+ "debug_dataset": bool,
244
+ "max_token_length": Any(None, int),
245
+ "prior_loss_weight": Any(float, int),
246
+ }
247
+ # for handling default None value of argparse
248
+ ARGPARSE_NULLABLE_OPTNAMES = [
249
+ "face_crop_aug_range",
250
+ "resolution",
251
+ ]
252
+ # prepare map because option name may differ among argparse and user config
253
+ ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
254
+ "train_batch_size": "batch_size",
255
+ "dataset_repeats": "num_repeats",
256
+ }
257
+
258
+ def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
259
+ assert support_dreambooth or support_finetuning or support_controlnet, (
260
+ "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
261
+ + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
262
+ )
263
+
264
+ self.db_subset_schema = self.__merge_dict(
265
+ self.SUBSET_ASCENDABLE_SCHEMA,
266
+ self.DB_SUBSET_DISTINCT_SCHEMA,
267
+ self.DB_SUBSET_ASCENDABLE_SCHEMA,
268
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
269
+ )
270
+
271
+ self.ft_subset_schema = self.__merge_dict(
272
+ self.SUBSET_ASCENDABLE_SCHEMA,
273
+ self.FT_SUBSET_DISTINCT_SCHEMA,
274
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
275
+ )
276
+
277
+ self.cn_subset_schema = self.__merge_dict(
278
+ self.SUBSET_ASCENDABLE_SCHEMA,
279
+ self.CN_SUBSET_DISTINCT_SCHEMA,
280
+ self.CN_SUBSET_ASCENDABLE_SCHEMA,
281
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
282
+ )
283
+
284
+ self.db_dataset_schema = self.__merge_dict(
285
+ self.DATASET_ASCENDABLE_SCHEMA,
286
+ self.SUBSET_ASCENDABLE_SCHEMA,
287
+ self.DB_SUBSET_ASCENDABLE_SCHEMA,
288
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
289
+ {"subsets": [self.db_subset_schema]},
290
+ )
291
+
292
+ self.ft_dataset_schema = self.__merge_dict(
293
+ self.DATASET_ASCENDABLE_SCHEMA,
294
+ self.SUBSET_ASCENDABLE_SCHEMA,
295
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
296
+ {"subsets": [self.ft_subset_schema]},
297
+ )
298
+
299
+ self.cn_dataset_schema = self.__merge_dict(
300
+ self.DATASET_ASCENDABLE_SCHEMA,
301
+ self.SUBSET_ASCENDABLE_SCHEMA,
302
+ self.CN_SUBSET_ASCENDABLE_SCHEMA,
303
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
304
+ {"subsets": [self.cn_subset_schema]},
305
+ )
306
+
307
+ if support_dreambooth and support_finetuning:
308
+
309
+ def validate_flex_dataset(dataset_config: dict):
310
+ subsets_config = dataset_config.get("subsets", [])
311
+
312
+ if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]):
313
+ return Schema(self.cn_dataset_schema)(dataset_config)
314
+ # check dataset meets FT style
315
+ # NOTE: all FT subsets should have "metadata_file"
316
+ elif all(["metadata_file" in subset for subset in subsets_config]):
317
+ return Schema(self.ft_dataset_schema)(dataset_config)
318
+ # check dataset meets DB style
319
+ # NOTE: all DB subsets should have no "metadata_file"
320
+ elif all(["metadata_file" not in subset for subset in subsets_config]):
321
+ return Schema(self.db_dataset_schema)(dataset_config)
322
+ else:
323
+ raise voluptuous.Invalid(
324
+ "DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。"
325
+ )
326
+
327
+ self.dataset_schema = validate_flex_dataset
328
+ elif support_dreambooth:
329
+ if support_controlnet:
330
+ self.dataset_schema = self.cn_dataset_schema
331
+ else:
332
+ self.dataset_schema = self.db_dataset_schema
333
+ elif support_finetuning:
334
+ self.dataset_schema = self.ft_dataset_schema
335
+ elif support_controlnet:
336
+ self.dataset_schema = self.cn_dataset_schema
337
+
338
+ self.general_schema = self.__merge_dict(
339
+ self.DATASET_ASCENDABLE_SCHEMA,
340
+ self.SUBSET_ASCENDABLE_SCHEMA,
341
+ self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
342
+ self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
343
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
344
+ )
345
+
346
+ self.user_config_validator = Schema(
347
+ {
348
+ "general": self.general_schema,
349
+ "datasets": [self.dataset_schema],
350
+ }
351
+ )
352
+
353
+ self.argparse_schema = self.__merge_dict(
354
+ self.general_schema,
355
+ self.ARGPARSE_SPECIFIC_SCHEMA,
356
+ {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
357
+ {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
358
+ )
359
+
360
+ self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
361
+
362
+ def sanitize_user_config(self, user_config: dict) -> dict:
363
+ try:
364
+ return self.user_config_validator(user_config)
365
+ except MultipleInvalid:
366
+ # TODO: make the error message shown on failure easier to understand
367
+ logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
368
+ raise
369
+
370
+ # NOTE: The argparse result does not strictly need to be sanitized,
371
+ # but doing so helps us detect program bugs
372
+ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
373
+ try:
374
+ return self.argparse_config_validator(argparse_namespace)
375
+ except MultipleInvalid:
376
+ # XXX: this should be a bug
377
+ logger.error(
378
+ "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
379
+ )
380
+ raise
381
+
382
+ # NOTE: values are overwritten by later dicts when the same key appears more than once
383
+ @staticmethod
384
+ def __merge_dict(*dict_list: dict) -> dict:
385
+ merged = {}
386
+ for schema in dict_list:
387
+ # merged |= schema
388
+ for k, v in schema.items():
389
+ merged[k] = v
390
+ return merged
391
+
392
+
393
+ class BlueprintGenerator:
394
+ BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
395
+
396
+ def __init__(self, sanitizer: ConfigSanitizer):
397
+ self.sanitizer = sanitizer
398
+
399
+ # runtime_params is for parameters that are only configurable at runtime, such as the tokenizer
400
+ def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
401
+ sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
402
+ sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
403
+
404
+ # convert argparse namespace to dict like config
405
+ # NOTE: it is ok to have extra entries in dict
406
+ optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
407
+ argparse_config = {
408
+ optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()
409
+ }
410
+
411
+ general_config = sanitized_user_config.get("general", {})
412
+
413
+ dataset_blueprints = []
414
+ for dataset_config in sanitized_user_config.get("datasets", []):
415
+ # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
416
+ subsets = dataset_config.get("subsets", [])
417
+ is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
418
+ is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
419
+ if is_controlnet:
420
+ subset_params_klass = ControlNetSubsetParams
421
+ dataset_params_klass = ControlNetDatasetParams
422
+ elif is_dreambooth:
423
+ subset_params_klass = DreamBoothSubsetParams
424
+ dataset_params_klass = DreamBoothDatasetParams
425
+ else:
426
+ subset_params_klass = FineTuningSubsetParams
427
+ dataset_params_klass = FineTuningDatasetParams
428
+
429
+ subset_blueprints = []
430
+ for subset_config in subsets:
431
+ params = self.generate_params_by_fallbacks(
432
+ subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params]
433
+ )
434
+ subset_blueprints.append(SubsetBlueprint(params))
435
+
436
+ params = self.generate_params_by_fallbacks(
437
+ dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
438
+ )
439
+ dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))
440
+
441
+ dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
442
+
443
+ return Blueprint(dataset_group_blueprint)
444
+
445
+ @staticmethod
446
+ def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
447
+ name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
448
+ search_value = BlueprintGenerator.search_value
449
+ default_params = asdict(param_klass())
450
+ param_names = default_params.keys()
451
+
452
+ params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
453
+
454
+ return param_klass(**params)
455
+
456
+ @staticmethod
457
+ def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
458
+ for cand in fallbacks:
459
+ value = cand.get(key)
460
+ if value is not None:
461
+ return value
462
+
463
+ return default_value
464
+
465
+
466
+ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
467
+ datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
468
+
469
+ for dataset_blueprint in dataset_group_blueprint.datasets:
470
+ if dataset_blueprint.is_controlnet:
471
+ subset_klass = ControlNetSubset
472
+ dataset_klass = ControlNetDataset
473
+ elif dataset_blueprint.is_dreambooth:
474
+ subset_klass = DreamBoothSubset
475
+ dataset_klass = DreamBoothDataset
476
+ else:
477
+ subset_klass = FineTuningSubset
478
+ dataset_klass = FineTuningDataset
479
+
480
+ subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
481
+ dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
482
+ datasets.append(dataset)
483
+
484
+ # print info
485
+ info = ""
486
+ for i, dataset in enumerate(datasets):
487
+ is_dreambooth = isinstance(dataset, DreamBoothDataset)
488
+ is_controlnet = isinstance(dataset, ControlNetDataset)
489
+ info += dedent(
490
+ f"""\
491
+ [Dataset {i}]
492
+ batch_size: {dataset.batch_size}
493
+ resolution: {(dataset.width, dataset.height)}
494
+ enable_bucket: {dataset.enable_bucket}
495
+ network_multiplier: {dataset.network_multiplier}
496
+ """
497
+ )
498
+
499
+ if dataset.enable_bucket:
500
+ info += indent(
501
+ dedent(
502
+ f"""\
503
+ min_bucket_reso: {dataset.min_bucket_reso}
504
+ max_bucket_reso: {dataset.max_bucket_reso}
505
+ bucket_reso_steps: {dataset.bucket_reso_steps}
506
+ bucket_no_upscale: {dataset.bucket_no_upscale}
507
+ \n"""
508
+ ),
509
+ " ",
510
+ )
511
+ else:
512
+ info += "\n"
513
+
514
+ for j, subset in enumerate(dataset.subsets):
515
+ info += indent(
516
+ dedent(
517
+ f"""\
518
+ [Subset {j} of Dataset {i}]
519
+ image_dir: "{subset.image_dir}"
520
+ image_count: {subset.img_count}
521
+ num_repeats: {subset.num_repeats}
522
+ shuffle_caption: {subset.shuffle_caption}
523
+ keep_tokens: {subset.keep_tokens}
524
+ keep_tokens_separator: {subset.keep_tokens_separator}
525
+ caption_separator: {subset.caption_separator}
526
+ secondary_separator: {subset.secondary_separator}
527
+ enable_wildcard: {subset.enable_wildcard}
528
+ caption_dropout_rate: {subset.caption_dropout_rate}
529
+ caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
530
+ caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
531
+ caption_prefix: {subset.caption_prefix}
532
+ caption_suffix: {subset.caption_suffix}
533
+ color_aug: {subset.color_aug}
534
+ flip_aug: {subset.flip_aug}
535
+ face_crop_aug_range: {subset.face_crop_aug_range}
536
+ random_crop: {subset.random_crop}
537
+ token_warmup_min: {subset.token_warmup_min}
538
+ token_warmup_step: {subset.token_warmup_step}
539
+ alpha_mask: {subset.alpha_mask}
540
+ custom_attributes: {subset.custom_attributes}
541
+ """
542
+ ),
543
+ " ",
544
+ )
545
+
546
+ if is_dreambooth:
547
+ info += indent(
548
+ dedent(
549
+ f"""\
550
+ is_reg: {subset.is_reg}
551
+ class_tokens: {subset.class_tokens}
552
+ caption_extension: {subset.caption_extension}
553
+ \n"""
554
+ ),
555
+ " ",
556
+ )
557
+ elif not is_controlnet:
558
+ info += indent(
559
+ dedent(
560
+ f"""\
561
+ metadata_file: {subset.metadata_file}
562
+ \n"""
563
+ ),
564
+ " ",
565
+ )
566
+
567
+ logger.info(f"{info}")
568
+
569
+ # make buckets first because it determines the length of dataset
570
+ # and set the same seed for all datasets
571
+ seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
572
+ for i, dataset in enumerate(datasets):
573
+ logger.info(f"[Dataset {i}]")
574
+ dataset.make_buckets()
575
+ dataset.set_seed(seed)
576
+
577
+ return DatasetGroup(datasets)
578
+
579
+
580
+ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
581
+ def extract_dreambooth_params(name: str) -> Tuple[int, str]:
582
+ tokens = name.split("_")
583
+ try:
584
+ n_repeats = int(tokens[0])
585
+ except ValueError as e:
586
+ logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
587
+ return 0, ""
588
+ caption_by_folder = "_".join(tokens[1:])
589
+ return n_repeats, caption_by_folder
590
+
591
+ def generate(base_dir: Optional[str], is_reg: bool):
592
+ if base_dir is None:
593
+ return []
594
+
595
+ base_dir: Path = Path(base_dir)
596
+ if not base_dir.is_dir():
597
+ return []
598
+
599
+ subsets_config = []
600
+ for subdir in base_dir.iterdir():
601
+ if not subdir.is_dir():
602
+ continue
603
+
604
+ num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
605
+ if num_repeats < 1:
606
+ continue
607
+
608
+ subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
609
+ subsets_config.append(subset_config)
610
+
611
+ return subsets_config
612
+
613
+ subsets_config = []
614
+ subsets_config += generate(train_data_dir, False)
615
+ subsets_config += generate(reg_data_dir, True)
616
+
617
+ return subsets_config
618
+
619
+
620
+ def generate_controlnet_subsets_config_by_subdirs(
621
+ train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"
622
+ ):
623
+ def generate(base_dir: Optional[str]):
624
+ if base_dir is None:
625
+ return []
626
+
627
+ base_dir: Path = Path(base_dir)
628
+ if not base_dir.is_dir():
629
+ return []
630
+
631
+ subsets_config = []
632
+ subset_config = {
633
+ "image_dir": train_data_dir,
634
+ "conditioning_data_dir": conditioning_data_dir,
635
+ "caption_extension": caption_extension,
636
+ "num_repeats": 1,
637
+ }
638
+ subsets_config.append(subset_config)
639
+
640
+ return subsets_config
641
+
642
+ subsets_config = []
643
+ subsets_config += generate(train_data_dir)
644
+
645
+ return subsets_config
646
+
647
+
648
+ def load_user_config(file: str) -> dict:
649
+ file_path: Path = Path(file)
650
+ if not file_path.is_file():
651
+ #raise ValueError(f"file not found / ファイルが見つかりません: {file}")
652
+ return toml.loads(file)
653
+
654
+ if file_path.name.lower().endswith(".json"):
655
+ try:
656
+ with open(file, "r") as f:
657
+ config = json.load(f)
658
+ except Exception:
659
+ logger.error(
660
+ f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
661
+ )
662
+ raise
663
+ elif file_path.name.lower().endswith(".toml"):
664
+ try:
665
+ config = toml.load(file_path)
666
+ except Exception:
667
+ logger.error(
668
+ f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
669
+ )
670
+ raise
671
+ else:
672
+ raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file_path}")
673
+
674
+ return config
675
+
676
+
677
+ # for config test
678
+ if __name__ == "__main__":
679
+ parser = argparse.ArgumentParser()
680
+ parser.add_argument("--support_dreambooth", action="store_true")
681
+ parser.add_argument("--support_finetuning", action="store_true")
682
+ parser.add_argument("--support_controlnet", action="store_true")
683
+ parser.add_argument("--support_dropout", action="store_true")
684
+ parser.add_argument("dataset_config")
685
+ config_args, remain = parser.parse_known_args()
686
+
687
+ parser = argparse.ArgumentParser()
688
+ train_util.add_dataset_arguments(
689
+ parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout
690
+ )
691
+ train_util.add_training_arguments(parser, config_args.support_dreambooth)
692
+ argparse_namespace = parser.parse_args(remain)
693
+ train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
694
+
695
+ logger.info("[argparse_namespace]")
696
+ logger.info(f"{vars(argparse_namespace)}")
697
+
698
+ user_config = load_user_config(config_args.dataset_config)
699
+
700
+ logger.info("")
701
+ logger.info("[user_config]")
702
+ logger.info(f"{user_config}")
703
+
704
+ sanitizer = ConfigSanitizer(
705
+ config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
706
+ )
707
+ sanitized_user_config = sanitizer.sanitize_user_config(user_config)
708
+
709
+ logger.info("")
710
+ logger.info("[sanitized_user_config]")
711
+ logger.info(f"{sanitized_user_config}")
712
+
713
+ blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
714
+
715
+ logger.info("")
716
+ logger.info("[blueprint]")
717
+ logger.info(f"{blueprint}")
library/custom_offloading_utils.py ADDED
@@ -0,0 +1,227 @@
1
+ from concurrent.futures import ThreadPoolExecutor
2
+ import time
3
+ from typing import Optional
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from .device_utils import clean_memory_on_device
8
+
9
+
10
+ def synchronize_device(device: torch.device):
11
+ if device.type == "cuda":
12
+ torch.cuda.synchronize()
13
+ elif device.type == "xpu":
14
+ torch.xpu.synchronize()
15
+ elif device.type == "mps":
16
+ torch.mps.synchronize()
17
+
18
+
19
+ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
20
+ assert layer_to_cpu.__class__ == layer_to_cuda.__class__
21
+
22
+ weight_swap_jobs = []
23
+
24
+ # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
25
+ # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
26
+ # print(module_to_cpu.__class__, module_to_cuda.__class__)
27
+ # if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
28
+ # weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
29
+
30
+ modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
31
+ for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
32
+ if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
33
+ module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
34
+ if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
35
+ weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
36
+ else:
37
+ if module_to_cuda.weight.data.device.type != device.type:
38
+ # print(
39
+ # f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
40
+ # )
41
+ module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
42
+
43
+ torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
44
+
45
+ stream = torch.cuda.Stream()
46
+ with torch.cuda.stream(stream):
47
+ # cuda to cpu
48
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
49
+ cuda_data_view.record_stream(stream)
50
+ module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
51
+
52
+ stream.synchronize()
53
+
54
+ # cpu to cuda
55
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
56
+ cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
57
+ module_to_cuda.weight.data = cuda_data_view
58
+
59
+ stream.synchronize()
60
+ torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
61
+
62
+
63
+ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
64
+ """
65
+ not tested
66
+ """
67
+ assert layer_to_cpu.__class__ == layer_to_cuda.__class__
68
+
69
+ weight_swap_jobs = []
70
+ for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
71
+ if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
72
+ weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
73
+
74
+ # device to cpu
75
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
76
+ module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
77
+
78
+ synchronize_device(device)
79
+
80
+ # cpu to device
81
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
82
+ cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
83
+ module_to_cuda.weight.data = cuda_data_view
84
+
85
+ synchronize_device(device)
86
+
87
+
88
+ def weighs_to_device(layer: nn.Module, device: torch.device):
89
+ for module in layer.modules():
90
+ if hasattr(module, "weight") and module.weight is not None:
91
+ module.weight.data = module.weight.data.to(device, non_blocking=True)
92
+
93
+
94
+ class Offloader:
95
+ """
96
+ common offloading class
97
+ """
98
+
99
+ def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
100
+ self.num_blocks = num_blocks
101
+ self.blocks_to_swap = blocks_to_swap
102
+ self.device = device
103
+ self.debug = debug
104
+
105
+ self.thread_pool = ThreadPoolExecutor(max_workers=1)
106
+ self.futures = {}
107
+ self.cuda_available = device.type == "cuda"
108
+
109
+ def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
110
+ if self.cuda_available:
111
+ swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
112
+ else:
113
+ swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)
114
+
115
+ def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
116
+ def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
117
+ if self.debug:
118
+ start_time = time.perf_counter()
119
+ print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")
120
+
121
+ self.swap_weight_devices(block_to_cpu, block_to_cuda)
122
+
123
+ if self.debug:
124
+ print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
125
+ return bidx_to_cpu, bidx_to_cuda # , event
126
+
127
+ block_to_cpu = blocks[block_idx_to_cpu]
128
+ block_to_cuda = blocks[block_idx_to_cuda]
129
+
130
+ self.futures[block_idx_to_cuda] = self.thread_pool.submit(
131
+ move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
132
+ )
133
+
134
+ def _wait_blocks_move(self, block_idx):
135
+ if block_idx not in self.futures:
136
+ return
137
+
138
+ if self.debug:
139
+ print(f"Wait for block {block_idx}")
140
+ start_time = time.perf_counter()
141
+
142
+ future = self.futures.pop(block_idx)
143
+ _, bidx_to_cuda = future.result()
144
+
145
+ assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
146
+
147
+ if self.debug:
148
+ print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
149
+
150
+
151
+ class ModelOffloader(Offloader):
152
+ """
153
+ supports forward offloading
154
+ """
155
+
156
+ def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
157
+ super().__init__(num_blocks, blocks_to_swap, device, debug)
158
+
159
+ # register backward hooks
160
+ self.remove_handles = []
161
+ for i, block in enumerate(blocks):
162
+ hook = self.create_backward_hook(blocks, i)
163
+ if hook is not None:
164
+ handle = block.register_full_backward_hook(hook)
165
+ self.remove_handles.append(handle)
166
+
167
+ def __del__(self):
168
+ for handle in self.remove_handles:
169
+ handle.remove()
170
+
171
+ def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
172
+ # -1 for 0-based index
173
+ num_blocks_propagated = self.num_blocks - block_index - 1
174
+ swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
175
+ waiting = block_index > 0 and block_index <= self.blocks_to_swap
176
+
177
+ if not swapping and not waiting:
178
+ return None
179
+
180
+ # create hook
181
+ block_idx_to_cpu = self.num_blocks - num_blocks_propagated
182
+ block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
183
+ block_idx_to_wait = block_index - 1
184
+
185
+ def backward_hook(module, grad_input, grad_output):
186
+ if self.debug:
187
+ print(f"Backward hook for block {block_index}")
188
+
189
+ if swapping:
190
+ self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
191
+ if waiting:
192
+ self._wait_blocks_move(block_idx_to_wait)
193
+ return None
194
+
195
+ return backward_hook
196
+
197
+ def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
198
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
199
+ return
200
+
201
+ if self.debug:
202
+ print("Prepare block devices before forward")
203
+
204
+ for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
205
+ b.to(self.device)
206
+ weighs_to_device(b, self.device) # make sure weights are on device
207
+
208
+ for b in blocks[self.num_blocks - self.blocks_to_swap :]:
209
+ b.to(self.device) # move block to device first
210
+ weighs_to_device(b, "cpu") # make sure weights are on cpu
211
+
212
+ synchronize_device(self.device)
213
+ clean_memory_on_device(self.device)
214
+
215
+ def wait_for_block(self, block_idx: int):
216
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
217
+ return
218
+ self._wait_blocks_move(block_idx)
219
+
220
+ def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
221
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
222
+ return
223
+ if block_idx >= self.blocks_to_swap:
224
+ return
225
+ block_idx_to_cpu = block_idx
226
+ block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
227
+ self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
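For reference, a minimal sketch of the calling protocol ModelOffloader expects during a forward pass, using a toy stack of Linear layers as a stand-in for the FLUX transformer blocks. A CUDA device is assumed, since the fast swap path relies on CUDA streams; the block count and swap count are illustrative.

import torch
import torch.nn as nn
from library.custom_offloading_utils import ModelOffloader

device = torch.device("cuda")
blocks = [nn.Linear(64, 64) for _ in range(8)]           # stand-in for transformer blocks
offloader = ModelOffloader(blocks, num_blocks=8, blocks_to_swap=4, device=device)

offloader.prepare_block_devices_before_forward(blocks)   # first 4 on GPU, last 4 weights kept on CPU
x = torch.randn(1, 64, device=device)
for i, block in enumerate(blocks):
    offloader.wait_for_block(i)                          # ensure this block's weights are on the GPU
    x = block(x)
    offloader.submit_move_blocks(blocks, i)              # swap it back out while later blocks run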
library/custom_train_functions.py ADDED
@@ -0,0 +1,556 @@
1
+ import torch
2
+ import argparse
3
+ import random
4
+ import re
5
+ from typing import List, Optional, Union
6
+ from .utils import setup_logging
7
+
8
+ setup_logging()
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def prepare_scheduler_for_custom_training(noise_scheduler, device):
15
+ if hasattr(noise_scheduler, "all_snr"):
16
+ return
17
+
18
+ alphas_cumprod = noise_scheduler.alphas_cumprod
19
+ sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
20
+ sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
21
+ alpha = sqrt_alphas_cumprod
22
+ sigma = sqrt_one_minus_alphas_cumprod
23
+ all_snr = (alpha / sigma) ** 2
24
+
25
+ noise_scheduler.all_snr = all_snr.to(device)
26
+
27
+
28
+ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
29
+ # fix beta: zero terminal SNR
30
+ logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
31
+
32
+ def enforce_zero_terminal_snr(betas):
33
+ # Convert betas to alphas_bar_sqrt
34
+ alphas = 1 - betas
35
+ alphas_bar = alphas.cumprod(0)
36
+ alphas_bar_sqrt = alphas_bar.sqrt()
37
+
38
+ # Store old values.
39
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
40
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
41
+ # Shift so last timestep is zero.
42
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
43
+ # Scale so first timestep is back to old value.
44
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
45
+
46
+ # Convert alphas_bar_sqrt to betas
47
+ alphas_bar = alphas_bar_sqrt**2
48
+ alphas = alphas_bar[1:] / alphas_bar[:-1]
49
+ alphas = torch.cat([alphas_bar[0:1], alphas])
50
+ betas = 1 - alphas
51
+ return betas
52
+
53
+ betas = noise_scheduler.betas
54
+ betas = enforce_zero_terminal_snr(betas)
55
+ alphas = 1.0 - betas
56
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
57
+
58
+ # logger.info(f"original: {noise_scheduler.betas}")
59
+ # logger.info(f"fixed: {betas}")
60
+
61
+ noise_scheduler.betas = betas
62
+ noise_scheduler.alphas = alphas
63
+ noise_scheduler.alphas_cumprod = alphas_cumprod
64
+
65
+
66
+ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
67
+ snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
68
+ min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
69
+ if v_prediction:
70
+ snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
71
+ else:
72
+ snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
73
+ loss = loss * snr_weight
74
+ return loss
75
+
76
+
77
+ def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
78
+ scale = get_snr_scale(timesteps, noise_scheduler)
79
+ loss = loss * scale
80
+ return loss
81
+
82
+
83
+ def get_snr_scale(timesteps, noise_scheduler):
84
+ snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
85
+ snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
86
+ scale = snr_t / (snr_t + 1)
87
+ # # show debug info
88
+ # logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
89
+ return scale
90
+
91
+
92
+ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
93
+ scale = get_snr_scale(timesteps, noise_scheduler)
94
+ # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
95
+ loss = loss + loss / scale * v_pred_like_loss
96
+ return loss
97
+
98
+
99
+ def apply_debiased_estimation(loss, timesteps, noise_scheduler):
100
+ snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
101
+ snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
102
+ weight = 1 / torch.sqrt(snr_t)
103
+ loss = weight * loss
104
+ return loss
105
+
106
+
107
+ # TODO this logic is split between here and train_util; consolidate it in one place
108
+
109
+
110
+ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
111
+ parser.add_argument(
112
+ "--min_snr_gamma",
113
+ type=float,
114
+ default=None,
115
+ help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
116
+ )
117
+ parser.add_argument(
118
+ "--scale_v_pred_loss_like_noise_pred",
119
+ action="store_true",
120
+ help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
121
+ )
122
+ parser.add_argument(
123
+ "--v_pred_like_loss",
124
+ type=float,
125
+ default=None,
126
+ help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
127
+ )
128
+ parser.add_argument(
129
+ "--debiased_estimation_loss",
130
+ action="store_true",
131
+ help="debiased estimation loss / debiased estimation loss",
132
+ )
133
+ if support_weighted_captions:
134
+ parser.add_argument(
135
+ "--weighted_captions",
136
+ action="store_true",
137
+ default=False,
138
+ help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
139
+ )
140
+
141
+
142
+ re_attention = re.compile(
143
+ r"""
144
+ \\\(|
145
+ \\\)|
146
+ \\\[|
147
+ \\]|
148
+ \\\\|
149
+ \\|
150
+ \(|
151
+ \[|
152
+ :([+-]?[.\d]+)\)|
153
+ \)|
154
+ ]|
155
+ [^\\()\[\]:]+|
156
+ :
157
+ """,
158
+ re.X,
159
+ )
160
+
161
+
162
+ def parse_prompt_attention(text):
163
+ """
164
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
165
+ Accepted tokens are:
166
+ (abc) - increases attention to abc by a multiplier of 1.1
167
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
168
+ [abc] - decreases attention to abc by a multiplier of 1.1
169
+ \( - literal character '('
170
+ \[ - literal character '['
171
+ \) - literal character ')'
172
+ \] - literal character ']'
173
+ \\ - literal character '\'
174
+ anything else - just text
175
+ >>> parse_prompt_attention('normal text')
176
+ [['normal text', 1.0]]
177
+ >>> parse_prompt_attention('an (important) word')
178
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
179
+ >>> parse_prompt_attention('(unbalanced')
180
+ [['unbalanced', 1.1]]
181
+ >>> parse_prompt_attention('\(literal\]')
182
+ [['(literal]', 1.0]]
183
+ >>> parse_prompt_attention('(unnecessary)(parens)')
184
+ [['unnecessaryparens', 1.1]]
185
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
186
+ [['a ', 1.0],
187
+ ['house', 1.5730000000000004],
188
+ [' ', 1.1],
189
+ ['on', 1.0],
190
+ [' a ', 1.1],
191
+ ['hill', 0.55],
192
+ [', sun, ', 1.1],
193
+ ['sky', 1.4641000000000006],
194
+ ['.', 1.1]]
195
+ """
196
+
197
+ res = []
198
+ round_brackets = []
199
+ square_brackets = []
200
+
201
+ round_bracket_multiplier = 1.1
202
+ square_bracket_multiplier = 1 / 1.1
203
+
204
+ def multiply_range(start_position, multiplier):
205
+ for p in range(start_position, len(res)):
206
+ res[p][1] *= multiplier
207
+
208
+ for m in re_attention.finditer(text):
209
+ text = m.group(0)
210
+ weight = m.group(1)
211
+
212
+ if text.startswith("\\"):
213
+ res.append([text[1:], 1.0])
214
+ elif text == "(":
215
+ round_brackets.append(len(res))
216
+ elif text == "[":
217
+ square_brackets.append(len(res))
218
+ elif weight is not None and len(round_brackets) > 0:
219
+ multiply_range(round_brackets.pop(), float(weight))
220
+ elif text == ")" and len(round_brackets) > 0:
221
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
222
+ elif text == "]" and len(square_brackets) > 0:
223
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
224
+ else:
225
+ res.append([text, 1.0])
226
+
227
+ for pos in round_brackets:
228
+ multiply_range(pos, round_bracket_multiplier)
229
+
230
+ for pos in square_brackets:
231
+ multiply_range(pos, square_bracket_multiplier)
232
+
233
+ if len(res) == 0:
234
+ res = [["", 1.0]]
235
+
236
+ # merge runs of identical weights
237
+ i = 0
238
+ while i + 1 < len(res):
239
+ if res[i][1] == res[i + 1][1]:
240
+ res[i][0] += res[i + 1][0]
241
+ res.pop(i + 1)
242
+ else:
243
+ i += 1
244
+
245
+ return res
246
+
247
+
248
+ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
249
+ r"""
250
+ Tokenize a list of prompts and return its tokens with weights of each token.
251
+
252
+ No padding, starting or ending token is included.
253
+ """
254
+ tokens = []
255
+ weights = []
256
+ truncated = False
257
+ for text in prompt:
258
+ texts_and_weights = parse_prompt_attention(text)
259
+ text_token = []
260
+ text_weight = []
261
+ for word, weight in texts_and_weights:
262
+ # tokenize and discard the starting and the ending token
263
+ token = tokenizer(word).input_ids[1:-1]
264
+ text_token += token
265
+ # copy the weight by length of token
266
+ text_weight += [weight] * len(token)
267
+ # stop if the text is too long (longer than truncation limit)
268
+ if len(text_token) > max_length:
269
+ truncated = True
270
+ break
271
+ # truncate
272
+ if len(text_token) > max_length:
273
+ truncated = True
274
+ text_token = text_token[:max_length]
275
+ text_weight = text_weight[:max_length]
276
+ tokens.append(text_token)
277
+ weights.append(text_weight)
278
+ if truncated:
279
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
280
+ return tokens, weights
281
+
282
+
283
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
284
+ r"""
285
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
286
+ """
287
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
288
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
289
+ for i in range(len(tokens)):
290
+ tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
291
+ if no_boseos_middle:
292
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
293
+ else:
294
+ w = []
295
+ if len(weights[i]) == 0:
296
+ w = [1.0] * weights_length
297
+ else:
298
+ for j in range(max_embeddings_multiples):
299
+ w.append(1.0) # weight for starting token in this chunk
300
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
301
+ w.append(1.0) # weight for ending token in this chunk
302
+ w += [1.0] * (weights_length - len(w))
303
+ weights[i] = w[:]
304
+
305
+ return tokens, weights
306
+
307
+
308
+ def get_unweighted_text_embeddings(
309
+ tokenizer,
310
+ text_encoder,
311
+ text_input: torch.Tensor,
312
+ chunk_length: int,
313
+ clip_skip: int,
314
+ eos: int,
315
+ pad: int,
316
+ no_boseos_middle: Optional[bool] = True,
317
+ ):
318
+ """
319
+ When the length of tokens is a multiple of the capacity of the text encoder,
320
+ it should be split into chunks and sent to the text encoder individually.
321
+ """
322
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
323
+ if max_embeddings_multiples > 1:
324
+ text_embeddings = []
325
+ for i in range(max_embeddings_multiples):
326
+ # extract the i-th chunk
327
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
328
+
329
+ # cover the head and the tail by the starting and the ending tokens
330
+ text_input_chunk[:, 0] = text_input[0, 0]
331
+ if pad == eos: # v1
332
+ text_input_chunk[:, -1] = text_input[0, -1]
333
+ else: # v2
334
+ for j in range(len(text_input_chunk)):
335
+ if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
336
+ text_input_chunk[j, -1] = eos
337
+ if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
338
+ text_input_chunk[j, 1] = eos
339
+
340
+ if clip_skip is None or clip_skip == 1:
341
+ text_embedding = text_encoder(text_input_chunk)[0]
342
+ else:
343
+ enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
344
+ text_embedding = enc_out["hidden_states"][-clip_skip]
345
+ text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
346
+
347
+ if no_boseos_middle:
348
+ if i == 0:
349
+ # discard the ending token
350
+ text_embedding = text_embedding[:, :-1]
351
+ elif i == max_embeddings_multiples - 1:
352
+ # discard the starting token
353
+ text_embedding = text_embedding[:, 1:]
354
+ else:
355
+ # discard both starting and ending tokens
356
+ text_embedding = text_embedding[:, 1:-1]
357
+
358
+ text_embeddings.append(text_embedding)
359
+ text_embeddings = torch.concat(text_embeddings, axis=1)
360
+ else:
361
+ if clip_skip is None or clip_skip == 1:
362
+ text_embeddings = text_encoder(text_input)[0]
363
+ else:
364
+ enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
365
+ text_embeddings = enc_out["hidden_states"][-clip_skip]
366
+ text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
367
+ return text_embeddings
368
+
369
+
370
+ def get_weighted_text_embeddings(
371
+ tokenizer,
372
+ text_encoder,
373
+ prompt: Union[str, List[str]],
374
+ device,
375
+ max_embeddings_multiples: Optional[int] = 3,
376
+ no_boseos_middle: Optional[bool] = False,
377
+ clip_skip=None,
378
+ ):
379
+ r"""
380
+ Prompts can be assigned with local weights using brackets. For example,
381
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
382
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
383
+
384
+ Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
385
+
386
+ Args:
387
+ prompt (`str` or `List[str]`):
388
+ The prompt or prompts to guide the image generation.
389
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
390
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
391
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
392
+ If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
393
+ ending token in each of the chunk in the middle.
394
+ skip_parsing (`bool`, *optional*, defaults to `False`):
395
+ Skip the parsing of brackets.
396
+ skip_weighting (`bool`, *optional*, defaults to `False`):
397
+ Skip the weighting. When the parsing is skipped, it is forced True.
398
+ """
399
+ max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
400
+ if isinstance(prompt, str):
401
+ prompt = [prompt]
402
+
403
+ prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
404
+
405
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
406
+ max_length = max([len(token) for token in prompt_tokens])
407
+
408
+ max_embeddings_multiples = min(
409
+ max_embeddings_multiples,
410
+ (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
411
+ )
412
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
413
+ max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
414
+
415
+ # pad the length of tokens and weights
416
+ bos = tokenizer.bos_token_id
417
+ eos = tokenizer.eos_token_id
418
+ pad = tokenizer.pad_token_id
419
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
420
+ prompt_tokens,
421
+ prompt_weights,
422
+ max_length,
423
+ bos,
424
+ eos,
425
+ no_boseos_middle=no_boseos_middle,
426
+ chunk_length=tokenizer.model_max_length,
427
+ )
428
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
429
+
430
+ # get the embeddings
431
+ text_embeddings = get_unweighted_text_embeddings(
432
+ tokenizer,
433
+ text_encoder,
434
+ prompt_tokens,
435
+ tokenizer.model_max_length,
436
+ clip_skip,
437
+ eos,
438
+ pad,
439
+ no_boseos_middle=no_boseos_middle,
440
+ )
441
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
442
+
443
+ # assign weights to the prompts and normalize in the sense of mean
444
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
445
+ text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
446
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
447
+ text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
448
+
449
+ return text_embeddings
450
+
451
+
452
+ # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
453
+ def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
454
+ b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
455
+ u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
456
+ for i in range(iterations):
457
+ r = random.random() * 2 + 2 # Rather than always going 2x,
458
+ wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
459
+ noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
460
+ if wn == 1 or hn == 1:
461
+ break # Lowest resolution is 1x1
462
+ return noise / noise.std() # Scaled back to roughly unit variance
463
+
464
+
465
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
466
+ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
467
+ if noise_offset is None:
468
+ return noise
469
+ if adaptive_noise_scale is not None:
470
+ # latent shape: (batch_size, channels, height, width)
471
+ # abs mean value for each channel
472
+ latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
473
+
474
+ # multiply adaptive noise scale to the mean value and add it to the noise offset
475
+ noise_offset = noise_offset + adaptive_noise_scale * latent_mean
476
+ noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
477
+
478
+ noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
479
+ return noise
480
+
481
+
482
+ def apply_masked_loss(loss, batch):
483
+ if "conditioning_images" in batch:
484
+ # conditioning image is -1 to 1. we need to convert it to 0 to 1
485
+ mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
486
+ mask_image = mask_image / 2 + 0.5
487
+ # print(f"conditioning_image: {mask_image.shape}")
488
+ elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
489
+ # alpha mask is 0 to 1
490
+ mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
491
+ # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
492
+ else:
493
+ return loss
494
+
495
+ # resize to the same size as the loss
496
+ mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
497
+ loss = loss * mask_image
498
+ return loss
499
+
500
+
501
+ """
502
+ ##########################################
503
+ # Perlin Noise
504
+ def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
505
+ delta = (res[0] / shape[0], res[1] / shape[1])
506
+ d = (shape[0] // res[0], shape[1] // res[1])
507
+
508
+ grid = (
509
+ torch.stack(
510
+ torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
511
+ dim=-1,
512
+ )
513
+ % 1
514
+ )
515
+ angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
516
+ gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
517
+
518
+ tile_grads = (
519
+ lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
520
+ .repeat_interleave(d[0], 0)
521
+ .repeat_interleave(d[1], 1)
522
+ )
523
+ dot = lambda grad, shift: (
524
+ torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
525
+ * grad[: shape[0], : shape[1]]
526
+ ).sum(dim=-1)
527
+
528
+ n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
529
+ n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
530
+ n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
531
+ n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
532
+ t = fade(grid[: shape[0], : shape[1]])
533
+ return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
534
+
535
+
536
+ def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
537
+ noise = torch.zeros(shape, device=device)
538
+ frequency = 1
539
+ amplitude = 1
540
+ for _ in range(octaves):
541
+ noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
542
+ frequency *= 2
543
+ amplitude *= persistence
544
+ return noise
545
+
546
+
547
+ def perlin_noise(noise, device, octaves):
548
+ _, c, w, h = noise.shape
549
+ perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
550
+ noise_perlin = []
551
+ for _ in range(c):
552
+ noise_perlin.append(perlin())
553
+ noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
554
+ noise += noise_perlin # broadcast for each batch
555
+ return noise / noise.std() # Scaled back to roughly unit variance
556
+ """
library/deepspeed_utils.py ADDED
@@ -0,0 +1,139 @@
1
+ import os
2
+ import argparse
3
+ import torch
4
+ from accelerate import DeepSpeedPlugin, Accelerator
5
+
6
+ from .utils import setup_logging
7
+
8
+ setup_logging()
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def add_deepspeed_arguments(parser: argparse.ArgumentParser):
15
+ # DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
16
+ parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
17
+ parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
18
+ parser.add_argument(
19
+ "--offload_optimizer_device",
20
+ type=str,
21
+ default=None,
22
+ choices=[None, "cpu", "nvme"],
23
+ help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
24
+ )
25
+ parser.add_argument(
26
+ "--offload_optimizer_nvme_path",
27
+ type=str,
28
+ default=None,
29
+ help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
30
+ )
31
+ parser.add_argument(
32
+ "--offload_param_device",
33
+ type=str,
34
+ default=None,
35
+ choices=[None, "cpu", "nvme"],
36
+ help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
37
+ )
38
+ parser.add_argument(
39
+ "--offload_param_nvme_path",
40
+ type=str,
41
+ default=None,
42
+ help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
43
+ )
44
+ parser.add_argument(
45
+ "--zero3_init_flag",
46
+ action="store_true",
47
+ help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
48
+ "Only applicable with ZeRO Stage-3.",
49
+ )
50
+ parser.add_argument(
51
+ "--zero3_save_16bit_model",
52
+ action="store_true",
53
+ help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
54
+ )
55
+ parser.add_argument(
56
+ "--fp16_master_weights_and_gradients",
57
+ action="store_true",
58
+ help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
59
+ )
60
+
61
+
62
+ def prepare_deepspeed_args(args: argparse.Namespace):
63
+ if not args.deepspeed:
64
+ return
65
+
66
+ # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
67
+ args.max_data_loader_n_workers = 1
68
+
69
+
70
+ def prepare_deepspeed_plugin(args: argparse.Namespace):
71
+ if not args.deepspeed:
72
+ return None
73
+
74
+ try:
75
+ import deepspeed
76
+ except ImportError as e:
77
+ logger.error(
78
+ "deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
79
+ )
80
+ exit(1)
81
+
82
+ deepspeed_plugin = DeepSpeedPlugin(
83
+ zero_stage=args.zero_stage,
84
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
85
+ gradient_clipping=args.max_grad_norm,
86
+ offload_optimizer_device=args.offload_optimizer_device,
87
+ offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
88
+ offload_param_device=args.offload_param_device,
89
+ offload_param_nvme_path=args.offload_param_nvme_path,
90
+ zero3_init_flag=args.zero3_init_flag,
91
+ zero3_save_16bit_model=args.zero3_save_16bit_model,
92
+ )
93
+ deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
94
+ deepspeed_plugin.deepspeed_config["train_batch_size"] = 1#(
95
+ # args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
96
+ #)
97
+ deepspeed_plugin.set_mixed_precision(args.mixed_precision)
98
+ if args.mixed_precision.lower() == "fp16":
99
+ deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
100
+ if args.full_fp16 or args.fp16_master_weights_and_gradients:
101
+ if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
102
+ deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
103
+ logger.info("[DeepSpeed] full fp16 enable.")
104
+ else:
105
+ logger.info(
106
+ "[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
107
+ )
108
+
109
+ if args.offload_optimizer_device is not None:
110
+ logger.info("[DeepSpeed] start to manually build cpu_adam.")
111
+ deepspeed.ops.op_builder.CPUAdamBuilder().load()
112
+ logger.info("[DeepSpeed] building cpu_adam done.")
113
+
114
+ return deepspeed_plugin
115
+
116
+
117
+ # Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
118
+ def prepare_deepspeed_model(args: argparse.Namespace, **models):
119
+ # remove None from models
120
+ models = {k: v for k, v in models.items() if v is not None}
121
+
122
+ class DeepSpeedWrapper(torch.nn.Module):
123
+ def __init__(self, **kw_models) -> None:
124
+ super().__init__()
125
+ self.models = torch.nn.ModuleDict()
126
+
127
+ for key, model in kw_models.items():
128
+ if isinstance(model, list):
129
+ model = torch.nn.ModuleList(model)
130
+ assert isinstance(
131
+ model, torch.nn.Module
132
+ ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
133
+ self.models.update(torch.nn.ModuleDict({key: model}))
134
+
135
+ def get_models(self):
136
+ return self.models
137
+
138
+ ds_model = DeepSpeedWrapper(**models)
139
+ return ds_model
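
As a side note on the helpers defined above (prepare_deepspeed_args, prepare_deepspeed_plugin, prepare_deepspeed_model): a minimal sketch, not part of the upload, of how they are typically wired into a training script. It assumes the helpers are importable from library.deepspeed_utils and that `args`, `unet`, `text_encoder`, and `optimizer` are placeholders provided by the caller.

    import torch
    from accelerate import Accelerator
    from library.deepspeed_utils import (
        prepare_deepspeed_args,
        prepare_deepspeed_plugin,
        prepare_deepspeed_model,
    )

    def setup_with_deepspeed(args, unet: torch.nn.Module, text_encoder: torch.nn.Module, optimizer):
        prepare_deepspeed_args(args)                       # clamps max_data_loader_n_workers to 1
        deepspeed_plugin = prepare_deepspeed_plugin(args)  # None unless --deepspeed was given

        accelerator = Accelerator(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            mixed_precision=args.mixed_precision,
            deepspeed_plugin=deepspeed_plugin,
        )

        # DeepSpeed via accelerate accepts a single model, so wrap both networks and
        # run the training forward pass through the wrapper.
        ds_model = prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
        ds_model, optimizer = accelerator.prepare(ds_model, optimizer)
        return accelerator, ds_model, optimizer

The wrapped sub-modules should remain reachable after preparation, e.g. via accelerator.unwrap_model(ds_model).get_models()["unet"].
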
library/device_utils.py ADDED
@@ -0,0 +1,84 @@
1
+ import functools
2
+ import gc
3
+
4
+ import torch
5
+
6
+ try:
7
+ HAS_CUDA = torch.cuda.is_available()
8
+ except Exception:
9
+ HAS_CUDA = False
10
+
11
+ try:
12
+ HAS_MPS = torch.backends.mps.is_available()
13
+ except Exception:
14
+ HAS_MPS = False
15
+
16
+ try:
17
+ import intel_extension_for_pytorch as ipex # noqa
18
+
19
+ HAS_XPU = torch.xpu.is_available()
20
+ except Exception:
21
+ HAS_XPU = False
22
+
23
+
24
+ def clean_memory():
25
+ gc.collect()
26
+ if HAS_CUDA:
27
+ torch.cuda.empty_cache()
28
+ if HAS_XPU:
29
+ torch.xpu.empty_cache()
30
+ if HAS_MPS:
31
+ torch.mps.empty_cache()
32
+
33
+
34
+ def clean_memory_on_device(device: torch.device):
35
+ r"""
36
+ Clean memory on the specified device, will be called from training scripts.
37
+ """
38
+ gc.collect()
39
+
40
+ # device may be "cuda" or "cuda:0", so check device.type rather than the device itself
41
+ if device.type == "cuda":
42
+ torch.cuda.empty_cache()
43
+ if device.type == "xpu":
44
+ torch.xpu.empty_cache()
45
+ if device.type == "mps":
46
+ torch.mps.empty_cache()
47
+
48
+
49
+ @functools.lru_cache(maxsize=None)
50
+ def get_preferred_device() -> torch.device:
51
+ r"""
52
+ Do not call this function from training scripts. Use accelerator.device instead.
53
+ """
54
+ if HAS_CUDA:
55
+ device = torch.device("cuda")
56
+ elif HAS_XPU:
57
+ device = torch.device("xpu")
58
+ elif HAS_MPS:
59
+ device = torch.device("mps")
60
+ else:
61
+ device = torch.device("cpu")
62
+ print(f"get_preferred_device() -> {device}")
63
+ return device
64
+
65
+
66
+ def init_ipex():
67
+ """
68
+ Apply IPEX's CUDA hijacks using `library.ipex.ipex_init`.
69
+
70
+ This function should run right after importing torch and before doing anything else.
71
+
72
+ If IPEX is not available, this function does nothing.
73
+ """
74
+ try:
75
+ if HAS_XPU:
76
+ from .ipex import ipex_init
77
+
78
+ is_initialized, error_message = ipex_init()
79
+ if not is_initialized:
80
+ print("failed to initialize ipex:", error_message)
81
+ else:
82
+ return
83
+ except Exception as e:
84
+ print("failed to initialize ipex:", e)
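
For context, a minimal sketch (not part of the upload) of how these device helpers are typically used from a script; it assumes the repo root is on the Python path so the module can be imported as library.device_utils, and the tensor shape is an arbitrary example.

    import torch
    from library.device_utils import init_ipex, get_preferred_device, clean_memory_on_device

    init_ipex()  # no-op unless intel_extension_for_pytorch is importable

    device = get_preferred_device()            # prefers cuda, then xpu, then mps, then cpu
    x = torch.randn(4, 3, 512, 512, device=device)
    # ... do some work with x ...
    del x
    clean_memory_on_device(device)             # gc.collect() plus the backend's empty_cache()
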
library/flux_models.py ADDED
@@ -0,0 +1,1060 @@
1
+ # copy from FLUX repo: https://github.com/black-forest-labs/flux
2
+ # license: Apache-2.0 License
3
+
4
+ from dataclasses import dataclass
5
+ import math
6
+ from typing import Dict, List, Optional, Union
7
+
8
+ from .device_utils import init_ipex
9
+ from .custom_offloading_utils import ModelOffloader
10
+ init_ipex()
11
+
12
+ import torch
13
+ from einops import rearrange
14
+ from torch import Tensor, nn
15
+ from torch.utils.checkpoint import checkpoint
16
+
17
+ # USE_REENTRANT = True
18
+
19
+
20
+ @dataclass
21
+ class FluxParams:
22
+ in_channels: int
23
+ vec_in_dim: int
24
+ context_in_dim: int
25
+ hidden_size: int
26
+ mlp_ratio: float
27
+ num_heads: int
28
+ depth: int
29
+ depth_single_blocks: int
30
+ axes_dim: list[int]
31
+ theta: int
32
+ qkv_bias: bool
33
+ guidance_embed: bool
34
+
35
+
36
+ # region autoencoder
37
+
38
+
39
+ @dataclass
40
+ class AutoEncoderParams:
41
+ resolution: int
42
+ in_channels: int
43
+ ch: int
44
+ out_ch: int
45
+ ch_mult: list[int]
46
+ num_res_blocks: int
47
+ z_channels: int
48
+ scale_factor: float
49
+ shift_factor: float
50
+
51
+
52
+ def swish(x: Tensor) -> Tensor:
53
+ return x * torch.sigmoid(x)
54
+
55
+
56
+ class AttnBlock(nn.Module):
57
+ def __init__(self, in_channels: int):
58
+ super().__init__()
59
+ self.in_channels = in_channels
60
+
61
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
62
+
63
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
64
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
65
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
66
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
67
+
68
+ def attention(self, h_: Tensor) -> Tensor:
69
+ h_ = self.norm(h_)
70
+ q = self.q(h_)
71
+ k = self.k(h_)
72
+ v = self.v(h_)
73
+
74
+ b, c, h, w = q.shape
75
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
76
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
77
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
78
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
79
+
80
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
81
+
82
+ def forward(self, x: Tensor) -> Tensor:
83
+ return x + self.proj_out(self.attention(x))
84
+
85
+
86
+ class ResnetBlock(nn.Module):
87
+ def __init__(self, in_channels: int, out_channels: int):
88
+ super().__init__()
89
+ self.in_channels = in_channels
90
+ out_channels = in_channels if out_channels is None else out_channels
91
+ self.out_channels = out_channels
92
+
93
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
94
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
95
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
96
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
97
+ if self.in_channels != self.out_channels:
98
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
99
+
100
+ def forward(self, x):
101
+ h = x
102
+ h = self.norm1(h)
103
+ h = swish(h)
104
+ h = self.conv1(h)
105
+
106
+ h = self.norm2(h)
107
+ h = swish(h)
108
+ h = self.conv2(h)
109
+
110
+ if self.in_channels != self.out_channels:
111
+ x = self.nin_shortcut(x)
112
+
113
+ return x + h
114
+
115
+
116
+ class Downsample(nn.Module):
117
+ def __init__(self, in_channels: int):
118
+ super().__init__()
119
+ # no asymmetric padding in torch conv, must do it ourselves
120
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
121
+
122
+ def forward(self, x: Tensor):
123
+ pad = (0, 1, 0, 1)
124
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
125
+ x = self.conv(x)
126
+ return x
127
+
128
+
129
+ class Upsample(nn.Module):
130
+ def __init__(self, in_channels: int):
131
+ super().__init__()
132
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
133
+
134
+ def forward(self, x: Tensor):
135
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
136
+ x = self.conv(x)
137
+ return x
138
+
139
+
140
+ class Encoder(nn.Module):
141
+ def __init__(
142
+ self,
143
+ resolution: int,
144
+ in_channels: int,
145
+ ch: int,
146
+ ch_mult: list[int],
147
+ num_res_blocks: int,
148
+ z_channels: int,
149
+ ):
150
+ super().__init__()
151
+ self.ch = ch
152
+ self.num_resolutions = len(ch_mult)
153
+ self.num_res_blocks = num_res_blocks
154
+ self.resolution = resolution
155
+ self.in_channels = in_channels
156
+ # downsampling
157
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
158
+
159
+ curr_res = resolution
160
+ in_ch_mult = (1,) + tuple(ch_mult)
161
+ self.in_ch_mult = in_ch_mult
162
+ self.down = nn.ModuleList()
163
+ block_in = self.ch
164
+ for i_level in range(self.num_resolutions):
165
+ block = nn.ModuleList()
166
+ attn = nn.ModuleList()
167
+ block_in = ch * in_ch_mult[i_level]
168
+ block_out = ch * ch_mult[i_level]
169
+ for _ in range(self.num_res_blocks):
170
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
171
+ block_in = block_out
172
+ down = nn.Module()
173
+ down.block = block
174
+ down.attn = attn
175
+ if i_level != self.num_resolutions - 1:
176
+ down.downsample = Downsample(block_in)
177
+ curr_res = curr_res // 2
178
+ self.down.append(down)
179
+
180
+ # middle
181
+ self.mid = nn.Module()
182
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
183
+ self.mid.attn_1 = AttnBlock(block_in)
184
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
185
+
186
+ # end
187
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
188
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
189
+
190
+ def forward(self, x: Tensor) -> Tensor:
191
+ # downsampling
192
+ hs = [self.conv_in(x)]
193
+ for i_level in range(self.num_resolutions):
194
+ for i_block in range(self.num_res_blocks):
195
+ h = self.down[i_level].block[i_block](hs[-1])
196
+ if len(self.down[i_level].attn) > 0:
197
+ h = self.down[i_level].attn[i_block](h)
198
+ hs.append(h)
199
+ if i_level != self.num_resolutions - 1:
200
+ hs.append(self.down[i_level].downsample(hs[-1]))
201
+
202
+ # middle
203
+ h = hs[-1]
204
+ h = self.mid.block_1(h)
205
+ h = self.mid.attn_1(h)
206
+ h = self.mid.block_2(h)
207
+ # end
208
+ h = self.norm_out(h)
209
+ h = swish(h)
210
+ h = self.conv_out(h)
211
+ return h
212
+
213
+
214
+ class Decoder(nn.Module):
215
+ def __init__(
216
+ self,
217
+ ch: int,
218
+ out_ch: int,
219
+ ch_mult: list[int],
220
+ num_res_blocks: int,
221
+ in_channels: int,
222
+ resolution: int,
223
+ z_channels: int,
224
+ ):
225
+ super().__init__()
226
+ self.ch = ch
227
+ self.num_resolutions = len(ch_mult)
228
+ self.num_res_blocks = num_res_blocks
229
+ self.resolution = resolution
230
+ self.in_channels = in_channels
231
+ self.ffactor = 2 ** (self.num_resolutions - 1)
232
+
233
+ # compute in_ch_mult, block_in and curr_res at lowest res
234
+ block_in = ch * ch_mult[self.num_resolutions - 1]
235
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
236
+ self.z_shape = (1, z_channels, curr_res, curr_res)
237
+
238
+ # z to block_in
239
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
240
+
241
+ # middle
242
+ self.mid = nn.Module()
243
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
244
+ self.mid.attn_1 = AttnBlock(block_in)
245
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
246
+
247
+ # upsampling
248
+ self.up = nn.ModuleList()
249
+ for i_level in reversed(range(self.num_resolutions)):
250
+ block = nn.ModuleList()
251
+ attn = nn.ModuleList()
252
+ block_out = ch * ch_mult[i_level]
253
+ for _ in range(self.num_res_blocks + 1):
254
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
255
+ block_in = block_out
256
+ up = nn.Module()
257
+ up.block = block
258
+ up.attn = attn
259
+ if i_level != 0:
260
+ up.upsample = Upsample(block_in)
261
+ curr_res = curr_res * 2
262
+ self.up.insert(0, up) # prepend to get consistent order
263
+
264
+ # end
265
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
266
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
267
+
268
+ def forward(self, z: Tensor) -> Tensor:
269
+ # z to block_in
270
+ h = self.conv_in(z)
271
+
272
+ # middle
273
+ h = self.mid.block_1(h)
274
+ h = self.mid.attn_1(h)
275
+ h = self.mid.block_2(h)
276
+
277
+ # upsampling
278
+ for i_level in reversed(range(self.num_resolutions)):
279
+ for i_block in range(self.num_res_blocks + 1):
280
+ h = self.up[i_level].block[i_block](h)
281
+ if len(self.up[i_level].attn) > 0:
282
+ h = self.up[i_level].attn[i_block](h)
283
+ if i_level != 0:
284
+ h = self.up[i_level].upsample(h)
285
+
286
+ # end
287
+ h = self.norm_out(h)
288
+ h = swish(h)
289
+ h = self.conv_out(h)
290
+ return h
291
+
292
+
293
+ class DiagonalGaussian(nn.Module):
294
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
295
+ super().__init__()
296
+ self.sample = sample
297
+ self.chunk_dim = chunk_dim
298
+
299
+ def forward(self, z: Tensor) -> Tensor:
300
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
301
+ if self.sample:
302
+ std = torch.exp(0.5 * logvar)
303
+ return mean + std * torch.randn_like(mean)
304
+ else:
305
+ return mean
306
+
307
+
308
+ class AutoEncoder(nn.Module):
309
+ def __init__(self, params: AutoEncoderParams):
310
+ super().__init__()
311
+ self.encoder = Encoder(
312
+ resolution=params.resolution,
313
+ in_channels=params.in_channels,
314
+ ch=params.ch,
315
+ ch_mult=params.ch_mult,
316
+ num_res_blocks=params.num_res_blocks,
317
+ z_channels=params.z_channels,
318
+ )
319
+ self.decoder = Decoder(
320
+ resolution=params.resolution,
321
+ in_channels=params.in_channels,
322
+ ch=params.ch,
323
+ out_ch=params.out_ch,
324
+ ch_mult=params.ch_mult,
325
+ num_res_blocks=params.num_res_blocks,
326
+ z_channels=params.z_channels,
327
+ )
328
+ self.reg = DiagonalGaussian()
329
+
330
+ self.scale_factor = params.scale_factor
331
+ self.shift_factor = params.shift_factor
332
+
333
+ @property
334
+ def device(self) -> torch.device:
335
+ return next(self.parameters()).device
336
+
337
+ @property
338
+ def dtype(self) -> torch.dtype:
339
+ return next(self.parameters()).dtype
340
+
341
+ def encode(self, x: Tensor) -> Tensor:
342
+ z = self.reg(self.encoder(x))
343
+ z = self.scale_factor * (z - self.shift_factor)
344
+ return z
345
+
346
+ def decode(self, z: Tensor) -> Tensor:
347
+ z = z / self.scale_factor + self.shift_factor
348
+ return self.decoder(z)
349
+
350
+ def forward(self, x: Tensor) -> Tensor:
351
+ return self.decode(self.encode(x))
352
+
353
+
354
+ # endregion
355
+ # region config
356
+
357
+
358
+ @dataclass
359
+ class ModelSpec:
360
+ params: FluxParams
361
+ ae_params: AutoEncoderParams
362
+ ckpt_path: str | None
363
+ ae_path: str | None
364
+ # repo_id: str | None
365
+ # repo_flow: str | None
366
+ # repo_ae: str | None
367
+
368
+
369
+ configs = {
370
+ "dev": ModelSpec(
371
+ # repo_id="black-forest-labs/FLUX.1-dev",
372
+ # repo_flow="flux1-dev.sft",
373
+ # repo_ae="ae.sft",
374
+ ckpt_path=None, # os.getenv("FLUX_DEV"),
375
+ params=FluxParams(
376
+ in_channels=64,
377
+ vec_in_dim=768,
378
+ context_in_dim=4096,
379
+ hidden_size=3072,
380
+ mlp_ratio=4.0,
381
+ num_heads=24,
382
+ depth=19,
383
+ depth_single_blocks=38,
384
+ axes_dim=[16, 56, 56],
385
+ theta=10_000,
386
+ qkv_bias=True,
387
+ guidance_embed=True,
388
+ ),
389
+ ae_path=None, # os.getenv("AE"),
390
+ ae_params=AutoEncoderParams(
391
+ resolution=256,
392
+ in_channels=3,
393
+ ch=128,
394
+ out_ch=3,
395
+ ch_mult=[1, 2, 4, 4],
396
+ num_res_blocks=2,
397
+ z_channels=16,
398
+ scale_factor=0.3611,
399
+ shift_factor=0.1159,
400
+ ),
401
+ ),
402
+ "schnell": ModelSpec(
403
+ # repo_id="black-forest-labs/FLUX.1-schnell",
404
+ # repo_flow="flux1-schnell.sft",
405
+ # repo_ae="ae.sft",
406
+ ckpt_path=None, # os.getenv("FLUX_SCHNELL"),
407
+ params=FluxParams(
408
+ in_channels=64,
409
+ vec_in_dim=768,
410
+ context_in_dim=4096,
411
+ hidden_size=3072,
412
+ mlp_ratio=4.0,
413
+ num_heads=24,
414
+ depth=19,
415
+ depth_single_blocks=38,
416
+ axes_dim=[16, 56, 56],
417
+ theta=10_000,
418
+ qkv_bias=True,
419
+ guidance_embed=False,
420
+ ),
421
+ ae_path=None, # os.getenv("AE"),
422
+ ae_params=AutoEncoderParams(
423
+ resolution=256,
424
+ in_channels=3,
425
+ ch=128,
426
+ out_ch=3,
427
+ ch_mult=[1, 2, 4, 4],
428
+ num_res_blocks=2,
429
+ z_channels=16,
430
+ scale_factor=0.3611,
431
+ shift_factor=0.1159,
432
+ ),
433
+ ),
434
+ }
435
+
436
+
437
+ # endregion
438
+
439
+ # region math
440
+
441
+
442
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
443
+ q, k = apply_rope(q, k, pe)
444
+
445
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
446
+ x = rearrange(x, "B H L D -> B L (H D)")
447
+
448
+ return x
449
+
450
+
451
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
452
+ assert dim % 2 == 0
453
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
454
+ omega = 1.0 / (theta**scale)
455
+ out = torch.einsum("...n,d->...nd", pos, omega)
456
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
457
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
458
+ return out.float()
459
+
460
+
461
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
462
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
463
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
464
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
465
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
466
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
467
+
468
+
469
+ # endregion
470
+
471
+
472
+ # region layers
473
+
474
+
475
+ # for cpu_offload_checkpointing
476
+
477
+
478
+ def to_cuda(x):
479
+ if isinstance(x, torch.Tensor):
480
+ return x.cuda()
481
+ elif isinstance(x, (list, tuple)):
482
+ return [to_cuda(elem) for elem in x]
483
+ elif isinstance(x, dict):
484
+ return {k: to_cuda(v) for k, v in x.items()}
485
+ else:
486
+ return x
487
+
488
+
489
+ def to_cpu(x):
490
+ if isinstance(x, torch.Tensor):
491
+ return x.cpu()
492
+ elif isinstance(x, (list, tuple)):
493
+ return [to_cpu(elem) for elem in x]
494
+ elif isinstance(x, dict):
495
+ return {k: to_cpu(v) for k, v in x.items()}
496
+ else:
497
+ return x
498
+
499
+
500
+ class EmbedND(nn.Module):
501
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
502
+ super().__init__()
503
+ self.dim = dim
504
+ self.theta = theta
505
+ self.axes_dim = axes_dim
506
+
507
+ def forward(self, ids: Tensor) -> Tensor:
508
+ n_axes = ids.shape[-1]
509
+ emb = torch.cat(
510
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
511
+ dim=-3,
512
+ )
513
+
514
+ return emb.unsqueeze(1)
515
+
516
+
517
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
518
+ """
519
+ Create sinusoidal timestep embeddings.
520
+ :param t: a 1-D Tensor of N indices, one per batch element.
521
+ These may be fractional.
522
+ :param dim: the dimension of the output.
523
+ :param max_period: controls the minimum frequency of the embeddings.
524
+ :return: an (N, D) Tensor of positional embeddings.
525
+ """
526
+ t = time_factor * t
527
+ half = dim // 2
528
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
529
+
530
+ args = t[:, None].float() * freqs[None]
531
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
532
+ if dim % 2:
533
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
534
+ if torch.is_floating_point(t):
535
+ embedding = embedding.to(t)
536
+ return embedding
537
+
538
+
539
+ class MLPEmbedder(nn.Module):
540
+ def __init__(self, in_dim: int, hidden_dim: int):
541
+ super().__init__()
542
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
543
+ self.silu = nn.SiLU()
544
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
545
+
546
+ self.gradient_checkpointing = False
547
+
548
+ def enable_gradient_checkpointing(self):
549
+ self.gradient_checkpointing = True
550
+
551
+ def disable_gradient_checkpointing(self):
552
+ self.gradient_checkpointing = False
553
+
554
+ def _forward(self, x: Tensor) -> Tensor:
555
+ return self.out_layer(self.silu(self.in_layer(x)))
556
+
557
+ def forward(self, *args, **kwargs):
558
+ if self.training and self.gradient_checkpointing:
559
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
560
+ else:
561
+ return self._forward(*args, **kwargs)
562
+
563
+ # def forward(self, x):
564
+ # if self.training and self.gradient_checkpointing:
565
+ # def create_custom_forward(func):
566
+ # def custom_forward(*inputs):
567
+ # return func(*inputs)
568
+ # return custom_forward
569
+ # return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, use_reentrant=USE_REENTRANT)
570
+ # else:
571
+ # return self._forward(x)
572
+
573
+
574
+ class RMSNorm(torch.nn.Module):
575
+ def __init__(self, dim: int):
576
+ super().__init__()
577
+ self.scale = nn.Parameter(torch.ones(dim))
578
+
579
+ def forward(self, x: Tensor):
580
+ x_dtype = x.dtype
581
+ x = x.float()
582
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
583
+ # return (x * rrms).to(dtype=x_dtype) * self.scale
584
+ return ((x * rrms) * self.scale.float()).to(dtype=x_dtype)
585
+
586
+
587
+ class QKNorm(torch.nn.Module):
588
+ def __init__(self, dim: int):
589
+ super().__init__()
590
+ self.query_norm = RMSNorm(dim)
591
+ self.key_norm = RMSNorm(dim)
592
+
593
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
594
+ q = self.query_norm(q)
595
+ k = self.key_norm(k)
596
+ return q.to(v), k.to(v)
597
+
598
+
599
+ class SelfAttention(nn.Module):
600
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
601
+ super().__init__()
602
+ self.num_heads = num_heads
603
+ head_dim = dim // num_heads
604
+
605
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
606
+ self.norm = QKNorm(head_dim)
607
+ self.proj = nn.Linear(dim, dim)
608
+
609
+ # this is not called from DoubleStreamBlock/SingleStreamBlock because they use the attention function directly
610
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
611
+ qkv = self.qkv(x)
612
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
613
+ q, k = self.norm(q, k, v)
614
+ x = attention(q, k, v, pe=pe)
615
+ x = self.proj(x)
616
+ return x
617
+
618
+
619
+ @dataclass
620
+ class ModulationOut:
621
+ shift: Tensor
622
+ scale: Tensor
623
+ gate: Tensor
624
+
625
+
626
+ class Modulation(nn.Module):
627
+ def __init__(self, dim: int, double: bool):
628
+ super().__init__()
629
+ self.is_double = double
630
+ self.multiplier = 6 if double else 3
631
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
632
+
633
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
634
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
635
+
636
+ return (
637
+ ModulationOut(*out[:3]),
638
+ ModulationOut(*out[3:]) if self.is_double else None,
639
+ )
640
+
641
+
642
+ class DoubleStreamBlock(nn.Module):
643
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
644
+ super().__init__()
645
+
646
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
647
+ self.num_heads = num_heads
648
+ self.hidden_size = hidden_size
649
+ self.img_mod = Modulation(hidden_size, double=True)
650
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
651
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
652
+
653
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
654
+ self.img_mlp = nn.Sequential(
655
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
656
+ nn.GELU(approximate="tanh"),
657
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
658
+ )
659
+
660
+ self.txt_mod = Modulation(hidden_size, double=True)
661
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
662
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
663
+
664
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
665
+ self.txt_mlp = nn.Sequential(
666
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
667
+ nn.GELU(approximate="tanh"),
668
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
669
+ )
670
+
671
+ self.gradient_checkpointing = False
672
+ self.cpu_offload_checkpointing = False
673
+
674
+ def enable_gradient_checkpointing(self, cpu_offload: bool = False):
675
+ self.gradient_checkpointing = True
676
+ self.cpu_offload_checkpointing = cpu_offload
677
+
678
+ def disable_gradient_checkpointing(self):
679
+ self.gradient_checkpointing = False
680
+ self.cpu_offload_checkpointing = False
681
+
682
+ def _forward(
683
+ self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
684
+ ) -> tuple[Tensor, Tensor]:
685
+ img_mod1, img_mod2 = self.img_mod(vec)
686
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
687
+
688
+ # prepare image for attention
689
+ img_modulated = self.img_norm1(img)
690
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
691
+ img_qkv = self.img_attn.qkv(img_modulated)
692
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
693
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
694
+
695
+ # prepare txt for attention
696
+ txt_modulated = self.txt_norm1(txt)
697
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
698
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
699
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
700
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
701
+
702
+ # run actual attention
703
+ q = torch.cat((txt_q, img_q), dim=2)
704
+ k = torch.cat((txt_k, img_k), dim=2)
705
+ v = torch.cat((txt_v, img_v), dim=2)
706
+
707
+ # make attention mask if not None
708
+ attn_mask = None
709
+ if txt_attention_mask is not None:
710
+ # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
711
+ attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
712
+ attn_mask = torch.cat(
713
+ (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1
714
+ ) # b, seq_len + img_len
715
+
716
+ # broadcast attn_mask to all heads
717
+ attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
718
+
719
+ attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
720
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
721
+
722
+ # calculate the img blocks
723
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
724
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
725
+
726
+ # calculate the txt blocks
727
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
728
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
729
+ return img, txt
730
+
731
+ def forward(
732
+ self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
733
+ ) -> tuple[Tensor, Tensor]:
734
+ if self.training and self.gradient_checkpointing:
735
+ if not self.cpu_offload_checkpointing:
736
+ return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False)
737
+ # cpu offload checkpointing
738
+
739
+ def create_custom_forward(func):
740
+ def custom_forward(*inputs):
741
+ cuda_inputs = to_cuda(inputs)
742
+ outputs = func(*cuda_inputs)
743
+ return to_cpu(outputs)
744
+
745
+ return custom_forward
746
+
747
+ return torch.utils.checkpoint.checkpoint(
748
+ create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False
749
+ )
750
+
751
+ else:
752
+ return self._forward(img, txt, vec, pe, txt_attention_mask)
753
+
754
+
755
+ class SingleStreamBlock(nn.Module):
756
+ """
757
+ A DiT block with parallel linear layers as described in
758
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
759
+ """
760
+
761
+ def __init__(
762
+ self,
763
+ hidden_size: int,
764
+ num_heads: int,
765
+ mlp_ratio: float = 4.0,
766
+ qk_scale: float | None = None,
767
+ ):
768
+ super().__init__()
769
+ self.hidden_dim = hidden_size
770
+ self.num_heads = num_heads
771
+ head_dim = hidden_size // num_heads
772
+ self.scale = qk_scale or head_dim**-0.5
773
+
774
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
775
+ # qkv and mlp_in
776
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
777
+ # proj and mlp_out
778
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
779
+
780
+ self.norm = QKNorm(head_dim)
781
+
782
+ self.hidden_size = hidden_size
783
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
784
+
785
+ self.mlp_act = nn.GELU(approximate="tanh")
786
+ self.modulation = Modulation(hidden_size, double=False)
787
+
788
+ self.gradient_checkpointing = False
789
+ self.cpu_offload_checkpointing = False
790
+
791
+ def enable_gradient_checkpointing(self, cpu_offload: bool = False):
792
+ self.gradient_checkpointing = True
793
+ self.cpu_offload_checkpointing = cpu_offload
794
+
795
+ def disable_gradient_checkpointing(self):
796
+ self.gradient_checkpointing = False
797
+ self.cpu_offload_checkpointing = False
798
+
799
+ def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
800
+ mod, _ = self.modulation(vec)
801
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
802
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
803
+
804
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
805
+ q, k = self.norm(q, k, v)
806
+
807
+ # make attention mask if not None
808
+ attn_mask = None
809
+ if txt_attention_mask is not None:
810
+ # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
811
+ attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
812
+ attn_mask = torch.cat(
813
+ (
814
+ attn_mask,
815
+ torch.ones(
816
+ attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool
817
+ ),
818
+ ),
819
+ dim=1,
820
+ ) # b, seq_len + img_len = x_len
821
+
822
+ # broadcast attn_mask to all heads
823
+ attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
824
+
825
+ # compute attention
826
+ attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
827
+
828
+ # compute activation in mlp stream, cat again and run second linear layer
829
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
830
+ return x + mod.gate * output
831
+
832
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
833
+ if self.training and self.gradient_checkpointing:
834
+ if not self.cpu_offload_checkpointing:
835
+ return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False)
836
+
837
+ # cpu offload checkpointing
838
+
839
+ def create_custom_forward(func):
840
+ def custom_forward(*inputs):
841
+ cuda_inputs = to_cuda(inputs)
842
+ outputs = func(*cuda_inputs)
843
+ return to_cpu(outputs)
844
+
845
+ return custom_forward
846
+
847
+ return torch.utils.checkpoint.checkpoint(
848
+ create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False
849
+ )
850
+ else:
851
+ return self._forward(x, vec, pe, txt_attention_mask)
852
+
853
+
854
+ class LastLayer(nn.Module):
855
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
856
+ super().__init__()
857
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
858
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
859
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
860
+
861
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
862
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
863
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
864
+ x = self.linear(x)
865
+ return x
866
+
867
+
868
+ # endregion
869
+
870
+
871
+ class Flux(nn.Module):
872
+ """
873
+ Transformer model for flow matching on sequences.
874
+ """
875
+
876
+ def __init__(self, params: FluxParams):
877
+ super().__init__()
878
+
879
+ self.params = params
880
+ self.in_channels = params.in_channels
881
+ self.out_channels = self.in_channels
882
+ if params.hidden_size % params.num_heads != 0:
883
+ raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
884
+ pe_dim = params.hidden_size // params.num_heads
885
+ if sum(params.axes_dim) != pe_dim:
886
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
887
+ self.hidden_size = params.hidden_size
888
+ self.num_heads = params.num_heads
889
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
890
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
891
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
892
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
893
+ self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
894
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
895
+
896
+ self.double_blocks = nn.ModuleList(
897
+ [
898
+ DoubleStreamBlock(
899
+ self.hidden_size,
900
+ self.num_heads,
901
+ mlp_ratio=params.mlp_ratio,
902
+ qkv_bias=params.qkv_bias,
903
+ )
904
+ for _ in range(params.depth)
905
+ ]
906
+ )
907
+
908
+ self.single_blocks = nn.ModuleList(
909
+ [
910
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
911
+ for _ in range(params.depth_single_blocks)
912
+ ]
913
+ )
914
+
915
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
916
+
917
+ self.gradient_checkpointing = False
918
+ self.cpu_offload_checkpointing = False
919
+ self.blocks_to_swap = None
920
+
921
+ self.offloader_double = None
922
+ self.offloader_single = None
923
+ self.num_double_blocks = len(self.double_blocks)
924
+ self.num_single_blocks = len(self.single_blocks)
925
+
926
+ @property
927
+ def device(self):
928
+ return next(self.parameters()).device
929
+
930
+ @property
931
+ def dtype(self):
932
+ return next(self.parameters()).dtype
933
+
934
+ def enable_gradient_checkpointing(self, cpu_offload: bool = False):
935
+ self.gradient_checkpointing = True
936
+ self.cpu_offload_checkpointing = cpu_offload
937
+
938
+ self.time_in.enable_gradient_checkpointing()
939
+ self.vector_in.enable_gradient_checkpointing()
940
+ if self.guidance_in.__class__ != nn.Identity:
941
+ self.guidance_in.enable_gradient_checkpointing()
942
+
943
+ for block in self.double_blocks + self.single_blocks:
944
+ block.enable_gradient_checkpointing(cpu_offload=cpu_offload)
945
+
946
+ print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}")
947
+
948
+ def disable_gradient_checkpointing(self):
949
+ self.gradient_checkpointing = False
950
+ self.cpu_offload_checkpointing = False
951
+
952
+ self.time_in.disable_gradient_checkpointing()
953
+ self.vector_in.disable_gradient_checkpointing()
954
+ if self.guidance_in.__class__ != nn.Identity:
955
+ self.guidance_in.disable_gradient_checkpointing()
956
+
957
+ for block in self.double_blocks + self.single_blocks:
958
+ block.disable_gradient_checkpointing()
959
+
960
+ print("FLUX: Gradient checkpointing disabled.")
961
+
962
+ def enable_block_swap(self, num_blocks: int, device: torch.device):
963
+ self.blocks_to_swap = num_blocks
964
+ double_blocks_to_swap = num_blocks // 2
965
+ single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
966
+
967
+ assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, (
968
+ f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. "
969
+ f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
970
+ )
971
+
972
+ self.offloader_double = ModelOffloader(
973
+ self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
974
+ )
975
+ self.offloader_single = ModelOffloader(
976
+ self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
977
+ )
978
+ print(
979
+ f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
980
+ )
981
+
982
+ def move_to_device_except_swap_blocks(self, device: torch.device):
983
+ # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
984
+ if self.blocks_to_swap:
985
+ save_double_blocks = self.double_blocks
986
+ save_single_blocks = self.single_blocks
987
+ self.double_blocks = None
988
+ self.single_blocks = None
989
+
990
+ self.to(device)
991
+
992
+ if self.blocks_to_swap:
993
+ self.double_blocks = save_double_blocks
994
+ self.single_blocks = save_single_blocks
995
+
996
+ def prepare_block_swap_before_forward(self):
997
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
998
+ return
999
+ self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
1000
+ self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
1001
+
1002
+ def forward(
1003
+ self,
1004
+ img: Tensor,
1005
+ img_ids: Tensor,
1006
+ txt: Tensor,
1007
+ txt_ids: Tensor,
1008
+ timesteps: Tensor,
1009
+ y: Tensor,
1010
+ guidance: Tensor | None = None,
1011
+ txt_attention_mask: Tensor | None = None,
1012
+ ) -> Tensor:
1013
+ if img.ndim != 3 or txt.ndim != 3:
1014
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
1015
+
1016
+ # running on sequences img
1017
+ img = self.img_in(img)
1018
+ vec = self.time_in(timestep_embedding(timesteps, 256))
1019
+ if self.params.guidance_embed:
1020
+ if guidance is None:
1021
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
1022
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
1023
+ vec = vec + self.vector_in(y)
1024
+ txt = self.txt_in(txt)
1025
+
1026
+ ids = torch.cat((txt_ids, img_ids), dim=1)
1027
+ pe = self.pe_embedder(ids)
1028
+
1029
+ if not self.blocks_to_swap:
1030
+ for block in self.double_blocks:
1031
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
1032
+ img = torch.cat((txt, img), 1)
1033
+ for block in self.single_blocks:
1034
+ img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
1035
+ else:
1036
+ for block_idx, block in enumerate(self.double_blocks):
1037
+ self.offloader_double.wait_for_block(block_idx)
1038
+
1039
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
1040
+
1041
+ self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
1042
+
1043
+ img = torch.cat((txt, img), 1)
1044
+
1045
+ for block_idx, block in enumerate(self.single_blocks):
1046
+ self.offloader_single.wait_for_block(block_idx)
1047
+
1048
+ img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
1049
+
1050
+ self.offloader_single.submit_move_blocks(self.single_blocks, block_idx)
1051
+
1052
+ img = img[:, txt.shape[1] :, ...]
1053
+
1054
+ if self.training and self.cpu_offload_checkpointing:
1055
+ img = img.to(self.device)
1056
+ vec = vec.to(self.device)
1057
+
1058
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
1059
+
1060
+ return img
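
A minimal sketch (not part of the upload) of how the Flux class above can be instantiated from the bundled config table and put into its memory-saving modes; weight loading and the actual forward pass are omitted, and a CUDA device is assumed for block swapping.

    import torch
    from library.flux_models import Flux, configs

    model = Flux(configs["dev"].params)           # FLUX.1 dev geometry: 19 double / 38 single blocks

    model.enable_gradient_checkpointing(cpu_offload=False)

    # swap 8 blocks between CPU and GPU to cut VRAM: 4 double blocks, 8 single blocks
    model.enable_block_swap(num_blocks=8, device=torch.device("cuda"))
    model.move_to_device_except_swap_blocks(torch.device("cuda"))
    model.prepare_block_swap_before_forward()     # called before each forward pass when swapping
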
library/flux_train_utils.py ADDED
@@ -0,0 +1,585 @@
1
+ import argparse
2
+ import math
3
+ import os
4
+ import numpy as np
5
+ import toml
6
+ import json
7
+ import time
8
+ from typing import Callable, Dict, List, Optional, Tuple, Union
9
+
10
+ import torch
11
+ from accelerate import Accelerator, PartialState
12
+ from transformers import CLIPTextModel
13
+ from tqdm import tqdm
14
+ from PIL import Image
15
+
16
+ from safetensors.torch import save_file
17
+ from . import flux_models, flux_utils, strategy_base, train_util
18
+ from .device_utils import init_ipex, clean_memory_on_device
19
+
20
+ init_ipex()
21
+
22
+ from .utils import setup_logging, mem_eff_save_file
23
+
24
+ setup_logging()
25
+ import logging
26
+
27
+ logger = logging.getLogger(__name__)
28
+ # from comfy.utils import ProgressBar
29
+
30
+ def sample_images(
31
+ accelerator: Accelerator,
32
+ args: argparse.Namespace,
33
+ epoch,
34
+ steps,
35
+ flux,
36
+ ae,
37
+ text_encoders,
38
+ sample_prompts_te_outputs,
39
+ validation_settings=None,
40
+ prompt_replacement=None,
41
+ ):
42
+
43
+ logger.info("")
44
+ logger.info(f"generating sample images at step: {steps}")
45
+
46
+ #distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
47
+
48
+ # unwrap unet and text_encoder(s)
49
+ flux = accelerator.unwrap_model(flux)
50
+ if text_encoders is not None:
51
+ text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
52
+ # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
53
+
54
+ prompts = []
55
+ for line in args.sample_prompts:
56
+ line = line.strip()
57
+ if len(line) > 0 and line[0] != "#":
58
+ prompts.append(line)
59
+
60
+ # preprocess prompts
61
+ for i in range(len(prompts)):
62
+ prompt_dict = prompts[i]
63
+ if isinstance(prompt_dict, str):
64
+ from .train_util import line_to_prompt_dict
65
+
66
+ prompt_dict = line_to_prompt_dict(prompt_dict)
67
+ prompts[i] = prompt_dict
68
+ assert isinstance(prompt_dict, dict)
69
+
70
+ # Add an index to the dict based on prompt position; it is used later to name image files. Also clean up extra data in the original prompt dict.
71
+ prompt_dict["enum"] = i
72
+ prompt_dict.pop("subset", None)
73
+
74
+ save_dir = args.output_dir + "/sample"
75
+ os.makedirs(save_dir, exist_ok=True)
76
+
77
+ # save random state to restore later
78
+ rng_state = torch.get_rng_state()
79
+ cuda_rng_state = None
80
+ try:
81
+ cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
82
+ except Exception:
83
+ pass
84
+
85
+ with torch.no_grad(), accelerator.autocast():
86
+ image_tensor_list = []
87
+ for prompt_dict in prompts:
88
+ image_tensor = sample_image_inference(
89
+ accelerator,
90
+ args,
91
+ flux,
92
+ text_encoders,
93
+ ae,
94
+ save_dir,
95
+ prompt_dict,
96
+ epoch,
97
+ steps,
98
+ sample_prompts_te_outputs,
99
+ prompt_replacement,
100
+ validation_settings
101
+ )
102
+ image_tensor_list.append(image_tensor)
103
+
104
+ torch.set_rng_state(rng_state)
105
+ if cuda_rng_state is not None:
106
+ torch.cuda.set_rng_state(cuda_rng_state)
107
+
108
+ clean_memory_on_device(accelerator.device)
109
+ return torch.cat(image_tensor_list, dim=0)
110
+
111
+
112
+ def sample_image_inference(
113
+ accelerator: Accelerator,
114
+ args: argparse.Namespace,
115
+ flux: flux_models.Flux,
116
+ text_encoders: Optional[List[CLIPTextModel]],
117
+ ae: flux_models.AutoEncoder,
118
+ save_dir,
119
+ prompt_dict,
120
+ epoch,
121
+ steps,
122
+ sample_prompts_te_outputs,
123
+ prompt_replacement,
124
+ validation_settings=None
125
+ ):
126
+ assert isinstance(prompt_dict, dict)
127
+ # negative_prompt = prompt_dict.get("negative_prompt")
128
+ if validation_settings is not None:
129
+ sample_steps = validation_settings["steps"]
130
+ width = validation_settings["width"]
131
+ height = validation_settings["height"]
132
+ scale = validation_settings["guidance_scale"]
133
+ seed = validation_settings["seed"]
134
+ base_shift = validation_settings["base_shift"]
135
+ max_shift = validation_settings["max_shift"]
136
+ shift = validation_settings["shift"]
137
+ else:
138
+ sample_steps = prompt_dict.get("sample_steps", 20)
139
+ width = prompt_dict.get("width", 512)
140
+ height = prompt_dict.get("height", 512)
141
+ scale = prompt_dict.get("scale", 3.5)
142
+ seed = prompt_dict.get("seed")
143
+ base_shift = 0.5
144
+ max_shift = 1.15
145
+ shift = True
146
+ # controlnet_image = prompt_dict.get("controlnet_image")
147
+ prompt: str = prompt_dict.get("prompt", "")
148
+ # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
149
+
150
+ if prompt_replacement is not None:
151
+ prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
152
+ # if negative_prompt is not None:
153
+ # negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
154
+
155
+ if seed is not None:
156
+ torch.manual_seed(seed)
157
+ torch.cuda.manual_seed(seed)
158
+ else:
159
+ # True random sample image generation
160
+ torch.seed()
161
+ torch.cuda.seed()
162
+
163
+ # if negative_prompt is None:
164
+ # negative_prompt = ""
165
+
166
+ height = max(64, height - height % 16) # round to divisible by 16
167
+ width = max(64, width - width % 16) # round to divisible by 16
168
+ logger.info(f"prompt: {prompt}")
169
+ # logger.info(f"negative_prompt: {negative_prompt}")
170
+ logger.info(f"height: {height}")
171
+ logger.info(f"width: {width}")
172
+ logger.info(f"sample_steps: {sample_steps}")
173
+ logger.info(f"scale: {scale}")
174
+ # logger.info(f"sample_sampler: {sampler_name}")
175
+ if seed is not None:
176
+ logger.info(f"seed: {seed}")
177
+
178
+ # encode prompts
179
+ tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
180
+ encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
181
+
182
+ text_encoder_conds = []
183
+ if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
184
+ text_encoder_conds = sample_prompts_te_outputs[prompt]
185
+ print(f"Using cached text encoder outputs for prompt: {prompt}")
186
+ if text_encoders is not None:
187
+ print(f"Encoding prompt: {prompt}")
188
+ tokens_and_masks = tokenize_strategy.tokenize(prompt)
189
+ # strategy has apply_t5_attn_mask option
190
+ encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
191
+
192
+ # if text_encoder_conds is not cached, use encoded_text_encoder_conds
193
+ if len(text_encoder_conds) == 0:
194
+ text_encoder_conds = encoded_text_encoder_conds
195
+ else:
196
+ # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
197
+ for i in range(len(encoded_text_encoder_conds)):
198
+ if encoded_text_encoder_conds[i] is not None:
199
+ text_encoder_conds[i] = encoded_text_encoder_conds[i]
200
+
201
+ l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
202
+
203
+ # sample image
204
+ weight_dtype = ae.dtype # TODO: pass dtype as an argument
205
+ packed_latent_height = height // 16
206
+ packed_latent_width = width // 16
207
+ noise = torch.randn(
208
+ 1,
209
+ packed_latent_height * packed_latent_width,
210
+ 16 * 2 * 2,
211
+ device=accelerator.device,
212
+ dtype=weight_dtype,
213
+ generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
214
+ )
215
+ timesteps = get_schedule(sample_steps, noise.shape[1], base_shift=base_shift, max_shift=max_shift, shift=shift) # FLUX.1 dev -> shift=True
216
+ #print("TIMESTEPS: ", timesteps)
217
+ img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
218
+ t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
219
+
220
+ with accelerator.autocast(), torch.no_grad():
221
+ x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask)
222
+
223
+ x = x.float()
224
+ x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
225
+
226
+ # latent to image
227
+ clean_memory_on_device(accelerator.device)
228
+ org_vae_device = ae.device # will be on cpu
229
+ ae.to(accelerator.device) # distributed_state.device is same as accelerator.device
230
+ with accelerator.autocast(), torch.no_grad():
231
+ x = ae.decode(x)
232
+ ae.to(org_vae_device)
233
+ clean_memory_on_device(accelerator.device)
234
+
235
+ x = x.clamp(-1, 1)
236
+ x = x.permute(0, 2, 3, 1)
237
+ image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
238
+
239
+ # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
240
+ # but adding 'enum' to the filename should be enough
241
+
242
+ ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
243
+ num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
244
+ seed_suffix = "" if seed is None else f"_{seed}"
245
+ i: int = prompt_dict["enum"]
246
+ img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
247
+ image.save(os.path.join(save_dir, img_filename))
248
+ return x
249
+
250
+ # send logs only when wandb is enabled
251
+ # try:
252
+ # wandb_tracker = accelerator.get_tracker("wandb")
253
+ # try:
254
+ # import wandb
255
+ # except ImportError: # already checked earlier, so this should not raise here
256
+ # raise ImportError("No wandb / wandb does not appear to be installed")
257
+
258
+ # wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
259
+ # except: # when wandb is disabled
260
+ # pass
261
+
262
+
263
+ def time_shift(mu: float, sigma: float, t: torch.Tensor):
264
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
265
+
266
+
267
+ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
268
+ m = (y2 - y1) / (x2 - x1)
269
+ b = y1 - m * x1
270
+ return lambda x: m * x + b
271
+
272
+
273
+ def get_schedule(
274
+ num_steps: int,
275
+ image_seq_len: int,
276
+ base_shift: float = 0.5,
277
+ max_shift: float = 1.15,
278
+ shift: bool = True,
279
+ ) -> list[float]:
280
+ # extra step for zero
281
+ timesteps = torch.linspace(1, 0, num_steps + 1)
282
+
283
+ # shifting the schedule to favor high timesteps for higher signal images
284
+ if shift:
285
+ # estimate mu by linear interpolation between two points
286
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
287
+ timesteps = time_shift(mu, 1.0, timesteps)
288
+
289
+ return timesteps.tolist()
290
+
291
+
292
+ def denoise(
293
+ model: flux_models.Flux,
294
+ img: torch.Tensor,
295
+ img_ids: torch.Tensor,
296
+ txt: torch.Tensor,
297
+ txt_ids: torch.Tensor,
298
+ vec: torch.Tensor,
299
+ timesteps: list[float],
300
+ guidance: float = 4.0,
301
+ t5_attn_mask: Optional[torch.Tensor] = None,
302
+ ):
303
+ # this is ignored for schnell
304
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
305
+ # comfy_pbar = ProgressBar(total=len(timesteps))
306
+ for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
307
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
308
+ model.prepare_block_swap_before_forward()
309
+ pred = model(
310
+ img=img,
311
+ img_ids=img_ids,
312
+ txt=txt,
313
+ txt_ids=txt_ids,
314
+ y=vec,
315
+ timesteps=t_vec,
316
+ guidance=guidance_vec,
317
+ txt_attention_mask=t5_attn_mask,
318
+ )
319
+
320
+ img = img + (t_prev - t_curr) * pred
321
+ # comfy_pbar.update(1)
322
+ model.prepare_block_swap_before_forward()
323
+ return img
324
+
325
+ # endregion
326
+
327
+
328
+ # region train
329
+ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
330
+ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
331
+ schedule_timesteps = noise_scheduler.timesteps.to(device)
332
+ timesteps = timesteps.to(device)
333
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
334
+
335
+ sigma = sigmas[step_indices].flatten()
336
+ while len(sigma.shape) < n_dim:
337
+ sigma = sigma.unsqueeze(-1)
338
+ return sigma
339
+
340
+
341
+ def compute_density_for_timestep_sampling(
342
+ weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
343
+ ):
344
+ """Compute the density for sampling the timesteps when doing SD3 training.
345
+
346
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
347
+
348
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
349
+ """
350
+ if weighting_scheme == "logit_normal":
351
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
352
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
353
+ u = torch.nn.functional.sigmoid(u)
354
+ elif weighting_scheme == "mode":
355
+ u = torch.rand(size=(batch_size,), device="cpu")
356
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
357
+ else:
358
+ u = torch.rand(size=(batch_size,), device="cpu")
359
+ return u
360
+
361
+
362
+ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
363
+ """Computes loss weighting scheme for SD3 training.
364
+
365
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
366
+
367
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
368
+ """
369
+ if weighting_scheme == "sigma_sqrt":
370
+ weighting = (sigmas**-2.0).float()
371
+ elif weighting_scheme == "cosmap":
372
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
373
+ weighting = 2 / (math.pi * bot)
374
+ else:
375
+ weighting = torch.ones_like(sigmas)
376
+ return weighting
377
+
378
+
379
+ def get_noisy_model_input_and_timesteps(
380
+ args, noise_scheduler, latents, noise, device, dtype
381
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
382
+ bsz, _, H, W = latents.shape
383
+ sigmas = None
384
+
385
+ if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
386
+ # Simple random t-based noise sampling
387
+ if args.timestep_sampling == "sigmoid":
388
+ # https://github.com/XLabs-AI/x-flux/tree/main
389
+ t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
390
+ else:
391
+ t = torch.rand((bsz,), device=device)
392
+
393
+ timesteps = t * 1000.0
394
+ t = t.view(-1, 1, 1, 1)
395
+ noisy_model_input = (1 - t) * latents + t * noise
396
+ elif args.timestep_sampling == "shift":
397
+ shift = args.discrete_flow_shift
398
+ logits_norm = torch.randn(bsz, device=device)
399
+ logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
400
+ timesteps = logits_norm.sigmoid()
401
+ timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
402
+
403
+ t = timesteps.view(-1, 1, 1, 1)
404
+ timesteps = timesteps * 1000.0
405
+ noisy_model_input = (1 - t) * latents + t * noise
406
+ elif args.timestep_sampling == "flux_shift":
407
+ logits_norm = torch.randn(bsz, device=device)
408
+ logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
409
+ timesteps = logits_norm.sigmoid()
410
+ mu = get_lin_function(y1=0.5, y2=1.15)((H // 2) * (W // 2))
411
+ timesteps = time_shift(mu, 1.0, timesteps)
412
+
413
+ t = timesteps.view(-1, 1, 1, 1)
414
+ timesteps = timesteps * 1000.0
415
+ noisy_model_input = (1 - t) * latents + t * noise
416
+ else:
417
+ # Sample a random timestep for each image
418
+ # for weighting schemes where we sample timesteps non-uniformly
419
+ u = compute_density_for_timestep_sampling(
420
+ weighting_scheme=args.weighting_scheme,
421
+ batch_size=bsz,
422
+ logit_mean=args.logit_mean,
423
+ logit_std=args.logit_std,
424
+ mode_scale=args.mode_scale,
425
+ )
426
+ indices = (u * noise_scheduler.config.num_train_timesteps).long()
427
+ timesteps = noise_scheduler.timesteps[indices].to(device=device)
428
+
429
+ # Add noise according to flow matching.
430
+ sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
431
+ noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
432
+
433
+ return noisy_model_input, timesteps, sigmas
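In the "shift" branch, the sigmoid-distributed t is remapped by t -> shift*t / (1 + (shift - 1)*t), which biases sampling toward higher noise levels; the "flux_shift" branch instead derives its shift from the latent resolution via get_lin_function/time_shift. A quick numeric check with the default shift of 3.0 (toy value):

    shift, t = 3.0, 0.5
    t_shifted = (t * shift) / (1 + (shift - 1) * t)  # 1.5 / 2.0 = 0.75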
434
+
435
+
436
+ def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
437
+ weighting = None
438
+ if args.model_prediction_type == "raw":
439
+ pass
440
+ elif args.model_prediction_type == "additive":
441
+ # add the model_pred to the noisy_model_input
442
+ model_pred = model_pred + noisy_model_input
443
+ elif args.model_prediction_type == "sigma_scaled":
444
+ # apply sigma scaling
445
+ model_pred = model_pred * (-sigmas) + noisy_model_input
446
+
447
+ # these weighting schemes use a uniform timestep sampling
448
+ # and instead post-weight the loss
449
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
450
+
451
+ return model_pred, weighting
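A sanity check on the "sigma_scaled" branch: with noisy = (1 - sigma) * latents + sigma * noise and an ideal velocity prediction of (noise - latents), model_pred * (-sigmas) + noisy_model_input recovers the clean latents (toy tensors, not from this module):

    import torch

    latents = torch.randn(2, 4, 8, 8)
    noise = torch.randn_like(latents)
    sigmas = torch.rand(2, 1, 1, 1)
    noisy = (1 - sigmas) * latents + sigmas * noise
    ideal_pred = noise - latents
    recovered = ideal_pred * (-sigmas) + noisy
    assert torch.allclose(recovered, latents, atol=1e-5)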
452
+
453
+
454
+ def save_models(
455
+ ckpt_path: str,
456
+ flux: flux_models.Flux,
457
+ sai_metadata: Optional[dict],
458
+ save_dtype: Optional[torch.dtype] = None,
459
+ use_mem_eff_save: bool = False,
460
+ ):
461
+ state_dict = {}
462
+
463
+ def update_sd(prefix, sd):
464
+ for k, v in sd.items():
465
+ key = prefix + k
466
+ if save_dtype is not None and v.dtype != save_dtype:
467
+ v = v.detach().clone().to("cpu").to(save_dtype)
468
+ state_dict[key] = v
469
+
470
+ update_sd("", flux.state_dict())
471
+
472
+ if not use_mem_eff_save:
473
+ save_file(state_dict, ckpt_path, metadata=sai_metadata)
474
+ else:
475
+ mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
476
+
477
+
478
+ def save_flux_model_on_train_end(
479
+ args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux
480
+ ):
481
+ def sd_saver(ckpt_file, epoch_no, global_step):
482
+ sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
483
+ save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
484
+
485
+ train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
486
+
487
+
488
+ # Saving at epoch end and at step intervals is unified here, since the metadata includes epoch/step and the arguments are the same
489
+ # on_epoch_end: True when called at the end of an epoch, False when called at a step interval
490
+ def save_flux_model_on_epoch_end_or_stepwise(
491
+ args: argparse.Namespace,
492
+ on_epoch_end: bool,
493
+ accelerator,
494
+ save_dtype: torch.dtype,
495
+ epoch: int,
496
+ num_train_epochs: int,
497
+ global_step: int,
498
+ flux: flux_models.Flux,
499
+ ):
500
+ def sd_saver(ckpt_file, epoch_no, global_step):
501
+ sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
502
+ save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
503
+
504
+ train_util.save_sd_model_on_epoch_end_or_stepwise_common(
505
+ args,
506
+ on_epoch_end,
507
+ accelerator,
508
+ True,
509
+ True,
510
+ epoch,
511
+ num_train_epochs,
512
+ global_step,
513
+ sd_saver,
514
+ None,
515
+ )
516
+
517
+
518
+ # endregion
519
+
520
+
521
+ def add_flux_train_arguments(parser: argparse.ArgumentParser):
522
+ parser.add_argument(
523
+ "--clip_l",
524
+ type=str,
525
+ help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)、float16が前提",
526
+ )
527
+ parser.add_argument(
528
+ "--t5xxl",
529
+ type=str,
530
+ help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提",
531
+ )
532
+ parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
533
+ parser.add_argument(
534
+ "--t5xxl_max_token_length",
535
+ type=int,
536
+ default=None,
537
+ help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev"
538
+ " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
539
+ )
540
+ parser.add_argument(
541
+ "--apply_t5_attn_mask",
542
+ action="store_true",
543
+ help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
544
+ )
545
+
546
+ parser.add_argument(
547
+ "--guidance_scale",
548
+ type=float,
549
+ default=3.5,
550
+ help="the FLUX.1 dev variant is a guidance distilled model",
551
+ )
552
+
553
+ parser.add_argument(
554
+ "--timestep_sampling",
555
+ choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
556
+ default="sigma",
557
+ help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
558
+ " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
559
+ )
560
+ parser.add_argument(
561
+ "--sigmoid_scale",
562
+ type=float,
563
+ default=1.0,
564
+ help='Scale factor for sigmoid timestep sampling (used when timestep-sampling is "sigmoid", "shift", or "flux_shift"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"、"shift"、"flux_shift"の場合に有効)。',
565
+ )
566
+ parser.add_argument(
567
+ "--model_prediction_type",
568
+ choices=["raw", "additive", "sigma_scaled"],
569
+ default="sigma_scaled",
570
+ help="How to interpret and process the model prediction: "
571
+ "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
572
+ " / モデル予測の解釈と処理方法:"
573
+ "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。",
574
+ )
575
+ parser.add_argument(
576
+ "--discrete_flow_shift",
577
+ type=float,
578
+ default=3.0,
579
+ help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
580
+ )
581
+ parser.add_argument(
582
+ "--bypass_flux_guidance",
583
+ action="store_true",
584
+ help="bypass the flux guidance module for Flex.1-Alpha training",
585
+ )
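A minimal usage sketch for these flags (assuming the repository root is on PYTHONPATH so this module imports as library.flux_train_utils; the argument values are illustrative):

    import argparse
    from library import flux_train_utils

    parser = argparse.ArgumentParser()
    flux_train_utils.add_flux_train_arguments(parser)
    args = parser.parse_args(
        ["--timestep_sampling", "shift", "--discrete_flow_shift", "3.0", "--guidance_scale", "1.0"]
    )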