Upload 132 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +7 -0
- Dockerfile +53 -0
- Dockerfile.cuda12.4 +53 -0
- LICENSE +7 -0
- README.md +54 -0
- __init__.py +12 -0
- __pycache__/train_network.cpython-310.pyc +0 -0
- advanced.png +3 -0
- app-launch.sh +5 -0
- app.py +1119 -0
- datasets/1 +0 -0
- docker-compose.yml +28 -0
- fine_tune.py +560 -0
- flags.png +0 -0
- flow.gif +3 -0
- flux_extract_lora.py +221 -0
- flux_train_comfy.py +806 -0
- flux_train_network_comfy.py +500 -0
- hf_token.json +3 -0
- icon.png +0 -0
- install.js +96 -0
- library/__init__.py +0 -0
- library/__pycache__/__init__.cpython-310.pyc +0 -0
- library/__pycache__/config_util.cpython-310.pyc +0 -0
- library/__pycache__/custom_offloading_utils.cpython-310.pyc +0 -0
- library/__pycache__/custom_train_functions.cpython-310.pyc +0 -0
- library/__pycache__/deepspeed_utils.cpython-310.pyc +0 -0
- library/__pycache__/device_utils.cpython-310.pyc +0 -0
- library/__pycache__/flux_models.cpython-310.pyc +0 -0
- library/__pycache__/flux_train_utils.cpython-310.pyc +0 -0
- library/__pycache__/flux_utils.cpython-310.pyc +0 -0
- library/__pycache__/huggingface_util.cpython-310.pyc +0 -0
- library/__pycache__/model_util.cpython-310.pyc +0 -0
- library/__pycache__/original_unet.cpython-310.pyc +0 -0
- library/__pycache__/sai_model_spec.cpython-310.pyc +0 -0
- library/__pycache__/sd3_models.cpython-310.pyc +0 -0
- library/__pycache__/sd3_utils.cpython-310.pyc +0 -0
- library/__pycache__/strategy_base.cpython-310.pyc +0 -0
- library/__pycache__/strategy_sd.cpython-310.pyc +0 -0
- library/__pycache__/train_util.cpython-310.pyc +3 -0
- library/__pycache__/utils.cpython-310.pyc +0 -0
- library/adafactor_fused.py +138 -0
- library/attention_processors.py +227 -0
- library/config_util.py +717 -0
- library/custom_offloading_utils.py +227 -0
- library/custom_train_functions.py +556 -0
- library/deepspeed_utils.py +139 -0
- library/device_utils.py +84 -0
- library/flux_models.py +1060 -0
- library/flux_train_utils.py +585 -0
.gitattributes
CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
advanced.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
flow.gif filter=lfs diff=lfs merge=lfs -text
|
38 |
+
library/__pycache__/train_util.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
|
39 |
+
publish_to_hf.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
sample.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
screenshot.png filter=lfs diff=lfs merge=lfs -text
|
42 |
+
seed.gif filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Base image with CUDA 12.2
|
2 |
+
FROM nvidia/cuda:12.2.2-base-ubuntu22.04
|
3 |
+
|
4 |
+
# Install pip if not already installed
|
5 |
+
RUN apt-get update -y && apt-get install -y \
|
6 |
+
python3-pip \
|
7 |
+
python3-dev \
|
8 |
+
git \
|
9 |
+
build-essential # Install dependencies for building extensions
|
10 |
+
|
11 |
+
# Define environment variables for UID and GID and local timezone
|
12 |
+
# ENV PUID=${PUID:-1000}
|
13 |
+
# ENV PGID=${PGID:-1000}
|
14 |
+
|
15 |
+
# Create a group with the specified GID
|
16 |
+
# RUN groupadd -g "${PGID}" appuser
|
17 |
+
# Create a user with the specified UID and GID
|
18 |
+
# RUN useradd -m -s /bin/sh -u "${PUID}" -g "${PGID}" appuser
|
19 |
+
|
20 |
+
WORKDIR /app
|
21 |
+
|
22 |
+
# Get sd-scripts from kohya-ss and install them
|
23 |
+
RUN git clone -b sd3 https://github.com/kohya-ss/sd-scripts && \
|
24 |
+
cd sd-scripts && \
|
25 |
+
pip install --no-cache-dir -r ./requirements.txt
|
26 |
+
|
27 |
+
# Install main application dependencies
|
28 |
+
COPY ./requirements.txt ./requirements.txt
|
29 |
+
RUN pip install --no-cache-dir -r ./requirements.txt
|
30 |
+
|
31 |
+
# Install Torch, Torchvision, and Torchaudio for CUDA 12.2
|
32 |
+
RUN pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu122/torch_stable.html
|
33 |
+
|
34 |
+
RUN chown -R appuser:appuser /app
|
35 |
+
|
36 |
+
# delete redundant requirements.txt and sd-scripts directory within the container
|
37 |
+
RUN rm -r ./sd-scripts
|
38 |
+
RUN rm ./requirements.txt
|
39 |
+
RUN pip install --force-reinstall -v "triton==3.1.0"
|
40 |
+
#Run application as non-root
|
41 |
+
# USER appuser
|
42 |
+
|
43 |
+
# Copy fluxgym application code
|
44 |
+
COPY . ./fluxgym
|
45 |
+
|
46 |
+
EXPOSE 7860
|
47 |
+
|
48 |
+
ENV GRADIO_SERVER_NAME="0.0.0.0"
|
49 |
+
|
50 |
+
WORKDIR /app/fluxgym
|
51 |
+
|
52 |
+
# Run fluxgym Python application
|
53 |
+
CMD ["python3", "./app.py"]
|
Dockerfile.cuda12.4
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Base image with CUDA 12.4
|
2 |
+
FROM nvidia/cuda:12.4.1-base-ubuntu22.04
|
3 |
+
|
4 |
+
# Install pip if not already installed
|
5 |
+
RUN apt-get update -y && apt-get install -y \
|
6 |
+
python3-pip \
|
7 |
+
python3-dev \
|
8 |
+
git \
|
9 |
+
build-essential # Install dependencies for building extensions
|
10 |
+
|
11 |
+
# Define environment variables for UID and GID and local timezone
|
12 |
+
ENV PUID=${PUID:-1000}
|
13 |
+
ENV PGID=${PGID:-1000}
|
14 |
+
|
15 |
+
# Create a group with the specified GID
|
16 |
+
RUN groupadd -g "${PGID}" appuser
|
17 |
+
# Create a user with the specified UID and GID
|
18 |
+
RUN useradd -m -s /bin/sh -u "${PUID}" -g "${PGID}" appuser
|
19 |
+
|
20 |
+
WORKDIR /app
|
21 |
+
|
22 |
+
# Get sd-scripts from kohya-ss and install them
|
23 |
+
RUN git clone -b sd3 https://github.com/kohya-ss/sd-scripts && \
|
24 |
+
cd sd-scripts && \
|
25 |
+
pip install --no-cache-dir -r ./requirements.txt
|
26 |
+
|
27 |
+
# Install main application dependencies
|
28 |
+
COPY ./requirements.txt ./requirements.txt
|
29 |
+
RUN pip install --no-cache-dir -r ./requirements.txt
|
30 |
+
|
31 |
+
# Install Torch, Torchvision, and Torchaudio for CUDA 12.4
|
32 |
+
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
|
33 |
+
|
34 |
+
RUN chown -R appuser:appuser /app
|
35 |
+
|
36 |
+
# delete redundant requirements.txt and sd-scripts directory within the container
|
37 |
+
RUN rm -r ./sd-scripts
|
38 |
+
RUN rm ./requirements.txt
|
39 |
+
|
40 |
+
#Run application as non-root
|
41 |
+
USER appuser
|
42 |
+
|
43 |
+
# Copy fluxgym application code
|
44 |
+
COPY . ./fluxgym
|
45 |
+
|
46 |
+
EXPOSE 7860
|
47 |
+
|
48 |
+
ENV GRADIO_SERVER_NAME="0.0.0.0"
|
49 |
+
|
50 |
+
WORKDIR /app/fluxgym
|
51 |
+
|
52 |
+
# Run fluxgym Python application
|
53 |
+
CMD ["python3", "./app.py"]
|
LICENSE
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright 2024 cocktailpeanut
|
2 |
+
|
3 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
4 |
+
|
5 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
6 |
+
|
7 |
+
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ComfyUI Flux Trainer
|
2 |
+
|
3 |
+
Wrapper for slightly modified kohya's training scripts: https://github.com/kohya-ss/sd-scripts
|
4 |
+
|
5 |
+
Including code from: https://github.com/KohakuBlueleaf/Lycoris
|
6 |
+
|
7 |
+
And https://github.com/LoganBooker/prodigy-plus-schedule-free
|
8 |
+
|
9 |
+
## DISCLAIMER:
|
10 |
+
I have **very** little previous experience in training anything, Flux is basically first model I've been inspired to learn. Previously I've only trained AnimateDiff Motion Loras, and built similar training nodes for it.
|
11 |
+
|
12 |
+
## DO NOT ASK ME FOR TRAINING ADVICE
|
13 |
+
I can not emphasize this enough, this repository is not for raising questions related to the training itself, that would be better done to kohya's repo. Even so keep in mind my implementation may have mistakes.
|
14 |
+
|
15 |
+
The default settings aren't necessarily any good, they are just the last (out of many) I've tried and worked for my dataset.
|
16 |
+
|
17 |
+
# THIS IS EXPERIMENTAL
|
18 |
+
Both these nodes and the underlaying implementation by kohya is work in progress and expected to change.
|
19 |
+
|
20 |
+
# Installation
|
21 |
+
1. Clone this repo into `custom_nodes` folder.
|
22 |
+
2. Install dependencies: `pip install -r requirements.txt`
|
23 |
+
or if you use the portable install, run this in ComfyUI_windows_portable -folder:
|
24 |
+
|
25 |
+
`python_embeded\python.exe -m pip install -r ComfyUI\custom_nodes\ComfyUI-FluxTrainer\requirements.txt`
|
26 |
+
|
27 |
+
In addition torch version 2.4.0 or higher is highly recommended.
|
28 |
+
|
29 |
+
Example workflow for LoRA training can be found in the examples folder, it utilizes additional nodes from:
|
30 |
+
|
31 |
+
https://github.com/kijai/ComfyUI-KJNodes
|
32 |
+
|
33 |
+
And some (optional) debugging nodes from:
|
34 |
+
|
35 |
+
https://github.com/rgthree/rgthree-comfy
|
36 |
+
|
37 |
+
For LoRA training the models need to be the normal fp8 or fp16 versions, also make sure the VAE is the non-diffusers version:
|
38 |
+
|
39 |
+
https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/ae.safetensors
|
40 |
+
|
41 |
+
For full model training the fp16 version of the main model needs to be used.
|
42 |
+
|
43 |
+
## Why train in ComfyUI?
|
44 |
+
- Familiar UI (obviously only if you are a Comfy user already)
|
45 |
+
- You can use same models you use for inference
|
46 |
+
- You can use same python environment, I faced no incompabilities
|
47 |
+
- You can build workflows to compare settings etc.
|
48 |
+
|
49 |
+
Currently supports LoRA training, and untested full finetune with code from kohya's scripts: https://github.com/kohya-ss/sd-scripts
|
50 |
+
|
51 |
+
Experimental support for LyCORIS training has been added as well, using code from: https://github.com/KohakuBlueleaf/Lycoris
|
52 |
+
|
53 |
+

|
54 |
+
|
__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
2 |
+
from .nodes_sd3 import NODE_CLASS_MAPPINGS as NODE_CLASS_MAPPINGS_SD3
|
3 |
+
from .nodes_sd3 import NODE_DISPLAY_NAME_MAPPINGS as NODE_DISPLAY_NAME_MAPPINGS_SD3
|
4 |
+
from .nodes_sdxl import NODE_CLASS_MAPPINGS as NODE_CLASS_MAPPINGS_SDXL
|
5 |
+
from .nodes_sdxl import NODE_DISPLAY_NAME_MAPPINGS as NODE_DISPLAY_NAME_MAPPINGS_SDXL
|
6 |
+
|
7 |
+
NODE_CLASS_MAPPINGS.update(NODE_CLASS_MAPPINGS_SD3)
|
8 |
+
NODE_CLASS_MAPPINGS.update(NODE_CLASS_MAPPINGS_SDXL)
|
9 |
+
NODE_DISPLAY_NAME_MAPPINGS.update(NODE_DISPLAY_NAME_MAPPINGS_SD3)
|
10 |
+
NODE_DISPLAY_NAME_MAPPINGS.update(NODE_DISPLAY_NAME_MAPPINGS_SDXL)
|
11 |
+
|
12 |
+
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
|
__pycache__/train_network.cpython-310.pyc
ADDED
Binary file (38.9 kB). View file
|
|
advanced.png
ADDED
![]() |
Git LFS Details
|
app-launch.sh
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
cd "`dirname "$0"`" || exit 1
|
4 |
+
. env/bin/activate
|
5 |
+
python app.py
|
app.py
ADDED
@@ -0,0 +1,1119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
4 |
+
os.environ['GRADIO_ANALYTICS_ENABLED'] = '0'
|
5 |
+
sys.path.insert(0, os.getcwd())
|
6 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), 'sd-scripts'))
|
7 |
+
import subprocess
|
8 |
+
import gradio as gr
|
9 |
+
from PIL import Image
|
10 |
+
import torch
|
11 |
+
import uuid
|
12 |
+
import shutil
|
13 |
+
import json
|
14 |
+
import yaml
|
15 |
+
from slugify import slugify
|
16 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
17 |
+
from gradio_logsview import LogsView, LogsViewRunner
|
18 |
+
from huggingface_hub import hf_hub_download, HfApi
|
19 |
+
from library import flux_train_utils, huggingface_util
|
20 |
+
from argparse import Namespace
|
21 |
+
import train_network
|
22 |
+
import toml
|
23 |
+
import re
|
24 |
+
MAX_IMAGES = 150
|
25 |
+
|
26 |
+
with open('models.yaml', 'r') as file:
|
27 |
+
models = yaml.safe_load(file)
|
28 |
+
|
29 |
+
def readme(base_model, lora_name, instance_prompt, sample_prompts):
|
30 |
+
|
31 |
+
# model license
|
32 |
+
model_config = models[base_model]
|
33 |
+
model_file = model_config["file"]
|
34 |
+
base_model_name = model_config["base"]
|
35 |
+
license = None
|
36 |
+
license_name = None
|
37 |
+
license_link = None
|
38 |
+
license_items = []
|
39 |
+
if "license" in model_config:
|
40 |
+
license = model_config["license"]
|
41 |
+
license_items.append(f"license: {license}")
|
42 |
+
if "license_name" in model_config:
|
43 |
+
license_name = model_config["license_name"]
|
44 |
+
license_items.append(f"license_name: {license_name}")
|
45 |
+
if "license_link" in model_config:
|
46 |
+
license_link = model_config["license_link"]
|
47 |
+
license_items.append(f"license_link: {license_link}")
|
48 |
+
license_str = "\n".join(license_items)
|
49 |
+
print(f"license_items={license_items}")
|
50 |
+
print(f"license_str = {license_str}")
|
51 |
+
|
52 |
+
# tags
|
53 |
+
tags = [ "text-to-image", "flux", "lora", "diffusers", "template:sd-lora", "fluxgym" ]
|
54 |
+
|
55 |
+
# widgets
|
56 |
+
widgets = []
|
57 |
+
sample_image_paths = []
|
58 |
+
output_name = slugify(lora_name)
|
59 |
+
samples_dir = resolve_path_without_quotes(f"outputs/{output_name}/sample")
|
60 |
+
try:
|
61 |
+
for filename in os.listdir(samples_dir):
|
62 |
+
# Filename Schema: [name]_[steps]_[index]_[timestamp].png
|
63 |
+
match = re.search(r"_(\d+)_(\d+)_(\d+)\.png$", filename)
|
64 |
+
if match:
|
65 |
+
steps, index, timestamp = int(match.group(1)), int(match.group(2)), int(match.group(3))
|
66 |
+
sample_image_paths.append((steps, index, f"sample/{filename}"))
|
67 |
+
|
68 |
+
# Sort by numeric index
|
69 |
+
sample_image_paths.sort(key=lambda x: x[0], reverse=True)
|
70 |
+
|
71 |
+
final_sample_image_paths = sample_image_paths[:len(sample_prompts)]
|
72 |
+
final_sample_image_paths.sort(key=lambda x: x[1])
|
73 |
+
for i, prompt in enumerate(sample_prompts):
|
74 |
+
_, _, image_path = final_sample_image_paths[i]
|
75 |
+
widgets.append(
|
76 |
+
{
|
77 |
+
"text": prompt,
|
78 |
+
"output": {
|
79 |
+
"url": image_path
|
80 |
+
},
|
81 |
+
}
|
82 |
+
)
|
83 |
+
except:
|
84 |
+
print(f"no samples")
|
85 |
+
dtype = "torch.bfloat16"
|
86 |
+
# Construct the README content
|
87 |
+
readme_content = f"""---
|
88 |
+
tags:
|
89 |
+
{yaml.dump(tags, indent=4).strip()}
|
90 |
+
{"widget:" if os.path.isdir(samples_dir) else ""}
|
91 |
+
{yaml.dump(widgets, indent=4).strip() if widgets else ""}
|
92 |
+
base_model: {base_model_name}
|
93 |
+
{"instance_prompt: " + instance_prompt if instance_prompt else ""}
|
94 |
+
{license_str}
|
95 |
+
---
|
96 |
+
|
97 |
+
# {lora_name}
|
98 |
+
|
99 |
+
A Flux LoRA trained on a local computer with [Fluxgym](https://github.com/cocktailpeanut/fluxgym)
|
100 |
+
|
101 |
+
<Gallery />
|
102 |
+
|
103 |
+
## Trigger words
|
104 |
+
|
105 |
+
{"You should use `" + instance_prompt + "` to trigger the image generation." if instance_prompt else "No trigger words defined."}
|
106 |
+
|
107 |
+
## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, Forge, etc.
|
108 |
+
|
109 |
+
Weights for this model are available in Safetensors format.
|
110 |
+
|
111 |
+
"""
|
112 |
+
return readme_content
|
113 |
+
|
114 |
+
def account_hf():
|
115 |
+
try:
|
116 |
+
with open("HF_TOKEN", "r") as file:
|
117 |
+
token = file.read()
|
118 |
+
api = HfApi(token=token)
|
119 |
+
try:
|
120 |
+
account = api.whoami()
|
121 |
+
return { "token": token, "account": account['name'] }
|
122 |
+
except:
|
123 |
+
return None
|
124 |
+
except:
|
125 |
+
return None
|
126 |
+
|
127 |
+
"""
|
128 |
+
hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner])
|
129 |
+
"""
|
130 |
+
def logout_hf():
|
131 |
+
os.remove("HF_TOKEN")
|
132 |
+
global current_account
|
133 |
+
current_account = account_hf()
|
134 |
+
print(f"current_account={current_account}")
|
135 |
+
return gr.update(value=""), gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False)
|
136 |
+
|
137 |
+
|
138 |
+
"""
|
139 |
+
hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner])
|
140 |
+
"""
|
141 |
+
def login_hf(hf_token):
|
142 |
+
api = HfApi(token=hf_token)
|
143 |
+
try:
|
144 |
+
account = api.whoami()
|
145 |
+
if account != None:
|
146 |
+
if "name" in account:
|
147 |
+
with open("HF_TOKEN", "w") as file:
|
148 |
+
file.write(hf_token)
|
149 |
+
global current_account
|
150 |
+
current_account = account_hf()
|
151 |
+
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(value=current_account["account"], visible=True)
|
152 |
+
return gr.update(), gr.update(), gr.update(), gr.update()
|
153 |
+
except:
|
154 |
+
print(f"incorrect hf_token")
|
155 |
+
return gr.update(), gr.update(), gr.update(), gr.update()
|
156 |
+
|
157 |
+
def upload_hf(base_model, lora_rows, repo_owner, repo_name, repo_visibility, hf_token):
|
158 |
+
src = lora_rows
|
159 |
+
repo_id = f"{repo_owner}/{repo_name}"
|
160 |
+
gr.Info(f"Uploading to Huggingface. Please Stand by...", duration=None)
|
161 |
+
args = Namespace(
|
162 |
+
huggingface_repo_id=repo_id,
|
163 |
+
huggingface_repo_type="model",
|
164 |
+
huggingface_repo_visibility=repo_visibility,
|
165 |
+
huggingface_path_in_repo="",
|
166 |
+
huggingface_token=hf_token,
|
167 |
+
async_upload=False
|
168 |
+
)
|
169 |
+
print(f"upload_hf args={args}")
|
170 |
+
huggingface_util.upload(args=args, src=src)
|
171 |
+
gr.Info(f"[Upload Complete] https://huggingface.co/{repo_id}", duration=None)
|
172 |
+
|
173 |
+
def load_captioning(uploaded_files, concept_sentence):
|
174 |
+
uploaded_images = [file for file in uploaded_files if not file.endswith('.txt')]
|
175 |
+
txt_files = [file for file in uploaded_files if file.endswith('.txt')]
|
176 |
+
txt_files_dict = {os.path.splitext(os.path.basename(txt_file))[0]: txt_file for txt_file in txt_files}
|
177 |
+
updates = []
|
178 |
+
if len(uploaded_images) <= 1:
|
179 |
+
raise gr.Error(
|
180 |
+
"Please upload at least 2 images to train your model (the ideal number with default settings is between 4-30)"
|
181 |
+
)
|
182 |
+
elif len(uploaded_images) > MAX_IMAGES:
|
183 |
+
raise gr.Error(f"For now, only {MAX_IMAGES} or less images are allowed for training")
|
184 |
+
# Update for the captioning_area
|
185 |
+
# for _ in range(3):
|
186 |
+
updates.append(gr.update(visible=True))
|
187 |
+
# Update visibility and image for each captioning row and image
|
188 |
+
for i in range(1, MAX_IMAGES + 1):
|
189 |
+
# Determine if the current row and image should be visible
|
190 |
+
visible = i <= len(uploaded_images)
|
191 |
+
|
192 |
+
# Update visibility of the captioning row
|
193 |
+
updates.append(gr.update(visible=visible))
|
194 |
+
|
195 |
+
# Update for image component - display image if available, otherwise hide
|
196 |
+
image_value = uploaded_images[i - 1] if visible else None
|
197 |
+
updates.append(gr.update(value=image_value, visible=visible))
|
198 |
+
|
199 |
+
corresponding_caption = False
|
200 |
+
if(image_value):
|
201 |
+
base_name = os.path.splitext(os.path.basename(image_value))[0]
|
202 |
+
if base_name in txt_files_dict:
|
203 |
+
with open(txt_files_dict[base_name], 'r') as file:
|
204 |
+
corresponding_caption = file.read()
|
205 |
+
|
206 |
+
# Update value of captioning area
|
207 |
+
text_value = corresponding_caption if visible and corresponding_caption else concept_sentence if visible and concept_sentence else None
|
208 |
+
updates.append(gr.update(value=text_value, visible=visible))
|
209 |
+
|
210 |
+
# Update for the sample caption area
|
211 |
+
updates.append(gr.update(visible=True))
|
212 |
+
updates.append(gr.update(visible=True))
|
213 |
+
|
214 |
+
return updates
|
215 |
+
|
216 |
+
def hide_captioning():
|
217 |
+
return gr.update(visible=False), gr.update(visible=False)
|
218 |
+
|
219 |
+
def resize_image(image_path, output_path, size):
|
220 |
+
with Image.open(image_path) as img:
|
221 |
+
width, height = img.size
|
222 |
+
if width < height:
|
223 |
+
new_width = size
|
224 |
+
new_height = int((size/width) * height)
|
225 |
+
else:
|
226 |
+
new_height = size
|
227 |
+
new_width = int((size/height) * width)
|
228 |
+
print(f"resize {image_path} : {new_width}x{new_height}")
|
229 |
+
img_resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
230 |
+
img_resized.save(output_path)
|
231 |
+
|
232 |
+
def create_dataset(destination_folder, size, *inputs):
|
233 |
+
print("Creating dataset")
|
234 |
+
images = inputs[0]
|
235 |
+
if not os.path.exists(destination_folder):
|
236 |
+
os.makedirs(destination_folder)
|
237 |
+
|
238 |
+
for index, image in enumerate(images):
|
239 |
+
# copy the images to the datasets folder
|
240 |
+
new_image_path = shutil.copy(image, destination_folder)
|
241 |
+
|
242 |
+
# if it's a caption text file skip the next bit
|
243 |
+
ext = os.path.splitext(new_image_path)[-1].lower()
|
244 |
+
if ext == '.txt':
|
245 |
+
continue
|
246 |
+
|
247 |
+
# resize the images
|
248 |
+
resize_image(new_image_path, new_image_path, size)
|
249 |
+
|
250 |
+
# copy the captions
|
251 |
+
|
252 |
+
original_caption = inputs[index + 1]
|
253 |
+
|
254 |
+
image_file_name = os.path.basename(new_image_path)
|
255 |
+
caption_file_name = os.path.splitext(image_file_name)[0] + ".txt"
|
256 |
+
caption_path = resolve_path_without_quotes(os.path.join(destination_folder, caption_file_name))
|
257 |
+
print(f"image_path={new_image_path}, caption_path = {caption_path}, original_caption={original_caption}")
|
258 |
+
# if caption_path exists, do not write
|
259 |
+
if os.path.exists(caption_path):
|
260 |
+
print(f"{caption_path} already exists. use the existing .txt file")
|
261 |
+
else:
|
262 |
+
print(f"{caption_path} create a .txt caption file")
|
263 |
+
with open(caption_path, 'w') as file:
|
264 |
+
file.write(original_caption)
|
265 |
+
|
266 |
+
print(f"destination_folder {destination_folder}")
|
267 |
+
return destination_folder
|
268 |
+
|
269 |
+
|
270 |
+
def run_captioning(images, concept_sentence, *captions):
|
271 |
+
print(f"run_captioning")
|
272 |
+
print(f"concept sentence {concept_sentence}")
|
273 |
+
print(f"captions {captions}")
|
274 |
+
#Load internally to not consume resources for training
|
275 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
276 |
+
print(f"device={device}")
|
277 |
+
torch_dtype = torch.float16
|
278 |
+
model = AutoModelForCausalLM.from_pretrained(
|
279 |
+
"multimodalart/Florence-2-large-no-flash-attn", torch_dtype=torch_dtype, trust_remote_code=True
|
280 |
+
).to(device)
|
281 |
+
processor = AutoProcessor.from_pretrained("multimodalart/Florence-2-large-no-flash-attn", trust_remote_code=True)
|
282 |
+
|
283 |
+
captions = list(captions)
|
284 |
+
for i, image_path in enumerate(images):
|
285 |
+
print(captions[i])
|
286 |
+
if isinstance(image_path, str): # If image is a file path
|
287 |
+
image = Image.open(image_path).convert("RGB")
|
288 |
+
|
289 |
+
prompt = "<DETAILED_CAPTION>"
|
290 |
+
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
|
291 |
+
print(f"inputs {inputs}")
|
292 |
+
|
293 |
+
generated_ids = model.generate(
|
294 |
+
input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
|
295 |
+
)
|
296 |
+
print(f"generated_ids {generated_ids}")
|
297 |
+
|
298 |
+
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
299 |
+
print(f"generated_text: {generated_text}")
|
300 |
+
parsed_answer = processor.post_process_generation(
|
301 |
+
generated_text, task=prompt, image_size=(image.width, image.height)
|
302 |
+
)
|
303 |
+
print(f"parsed_answer = {parsed_answer}")
|
304 |
+
caption_text = parsed_answer["<DETAILED_CAPTION>"].replace("The image shows ", "")
|
305 |
+
print(f"caption_text = {caption_text}, concept_sentence={concept_sentence}")
|
306 |
+
if concept_sentence:
|
307 |
+
caption_text = f"{concept_sentence} {caption_text}"
|
308 |
+
captions[i] = caption_text
|
309 |
+
|
310 |
+
yield captions
|
311 |
+
model.to("cpu")
|
312 |
+
del model
|
313 |
+
del processor
|
314 |
+
if torch.cuda.is_available():
|
315 |
+
torch.cuda.empty_cache()
|
316 |
+
|
317 |
+
def recursive_update(d, u):
|
318 |
+
for k, v in u.items():
|
319 |
+
if isinstance(v, dict) and v:
|
320 |
+
d[k] = recursive_update(d.get(k, {}), v)
|
321 |
+
else:
|
322 |
+
d[k] = v
|
323 |
+
return d
|
324 |
+
|
325 |
+
def download(base_model):
|
326 |
+
model = models[base_model]
|
327 |
+
model_file = model["file"]
|
328 |
+
repo = model["repo"]
|
329 |
+
|
330 |
+
# download unet
|
331 |
+
if base_model == "flux-dev" or base_model == "flux-schnell":
|
332 |
+
unet_folder = "models/unet"
|
333 |
+
else:
|
334 |
+
unet_folder = f"models/unet/{repo}"
|
335 |
+
unet_path = os.path.join(unet_folder, model_file)
|
336 |
+
if not os.path.exists(unet_path):
|
337 |
+
os.makedirs(unet_folder, exist_ok=True)
|
338 |
+
gr.Info(f"Downloading base model: {base_model}. Please wait. (You can check the terminal for the download progress)", duration=None)
|
339 |
+
print(f"download {base_model}")
|
340 |
+
hf_hub_download(repo_id=repo, local_dir=unet_folder, filename=model_file)
|
341 |
+
|
342 |
+
# download vae
|
343 |
+
vae_folder = "models/vae"
|
344 |
+
vae_path = os.path.join(vae_folder, "ae.sft")
|
345 |
+
if not os.path.exists(vae_path):
|
346 |
+
os.makedirs(vae_folder, exist_ok=True)
|
347 |
+
gr.Info(f"Downloading vae")
|
348 |
+
print(f"downloading ae.sft...")
|
349 |
+
hf_hub_download(repo_id="cocktailpeanut/xulf-dev", local_dir=vae_folder, filename="ae.sft")
|
350 |
+
|
351 |
+
# download clip
|
352 |
+
clip_folder = "models/clip"
|
353 |
+
clip_l_path = os.path.join(clip_folder, "clip_l.safetensors")
|
354 |
+
if not os.path.exists(clip_l_path):
|
355 |
+
os.makedirs(clip_folder, exist_ok=True)
|
356 |
+
gr.Info(f"Downloading clip...")
|
357 |
+
print(f"download clip_l.safetensors")
|
358 |
+
hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", local_dir=clip_folder, filename="clip_l.safetensors")
|
359 |
+
|
360 |
+
# download t5xxl
|
361 |
+
t5xxl_path = os.path.join(clip_folder, "t5xxl_fp16.safetensors")
|
362 |
+
if not os.path.exists(t5xxl_path):
|
363 |
+
print(f"download t5xxl_fp16.safetensors")
|
364 |
+
gr.Info(f"Downloading t5xxl...")
|
365 |
+
hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", local_dir=clip_folder, filename="t5xxl_fp16.safetensors")
|
366 |
+
|
367 |
+
|
368 |
+
def resolve_path(p):
|
369 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
370 |
+
norm_path = os.path.normpath(os.path.join(current_dir, p))
|
371 |
+
return f"\"{norm_path}\""
|
372 |
+
def resolve_path_without_quotes(p):
|
373 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
374 |
+
norm_path = os.path.normpath(os.path.join(current_dir, p))
|
375 |
+
return norm_path
|
376 |
+
|
377 |
+
def gen_sh(
|
378 |
+
base_model,
|
379 |
+
output_name,
|
380 |
+
resolution,
|
381 |
+
seed,
|
382 |
+
workers,
|
383 |
+
learning_rate,
|
384 |
+
network_dim,
|
385 |
+
max_train_epochs,
|
386 |
+
save_every_n_epochs,
|
387 |
+
timestep_sampling,
|
388 |
+
guidance_scale,
|
389 |
+
vram,
|
390 |
+
sample_prompts,
|
391 |
+
sample_every_n_steps,
|
392 |
+
*advanced_components
|
393 |
+
):
|
394 |
+
|
395 |
+
print(f"gen_sh: network_dim:{network_dim}, max_train_epochs={max_train_epochs}, save_every_n_epochs={save_every_n_epochs}, timestep_sampling={timestep_sampling}, guidance_scale={guidance_scale}, vram={vram}, sample_prompts={sample_prompts}, sample_every_n_steps={sample_every_n_steps}")
|
396 |
+
|
397 |
+
output_dir = resolve_path(f"outputs/{output_name}")
|
398 |
+
sample_prompts_path = resolve_path(f"outputs/{output_name}/sample_prompts.txt")
|
399 |
+
|
400 |
+
line_break = "\\"
|
401 |
+
file_type = "sh"
|
402 |
+
if sys.platform == "win32":
|
403 |
+
line_break = "^"
|
404 |
+
file_type = "bat"
|
405 |
+
|
406 |
+
############# Sample args ########################
|
407 |
+
sample = ""
|
408 |
+
if len(sample_prompts) > 0 and sample_every_n_steps > 0:
|
409 |
+
sample = f"""--sample_prompts={sample_prompts_path} --sample_every_n_steps="{sample_every_n_steps}" {line_break}"""
|
410 |
+
|
411 |
+
|
412 |
+
############# Optimizer args ########################
|
413 |
+
# if vram == "8G":
|
414 |
+
# optimizer = f"""--optimizer_type adafactor {line_break}
|
415 |
+
# --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break}
|
416 |
+
# --split_mode {line_break}
|
417 |
+
# --network_args "train_blocks=single" {line_break}
|
418 |
+
# --lr_scheduler constant_with_warmup {line_break}
|
419 |
+
# --max_grad_norm 0.0 {line_break}"""
|
420 |
+
if vram == "16G":
|
421 |
+
# 16G VRAM
|
422 |
+
optimizer = f"""--optimizer_type adafactor {line_break}
|
423 |
+
--optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break}
|
424 |
+
--lr_scheduler constant_with_warmup {line_break}
|
425 |
+
--max_grad_norm 0.0 {line_break}"""
|
426 |
+
elif vram == "12G":
|
427 |
+
# 12G VRAM
|
428 |
+
optimizer = f"""--optimizer_type adafactor {line_break}
|
429 |
+
--optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break}
|
430 |
+
--split_mode {line_break}
|
431 |
+
--network_args "train_blocks=single" {line_break}
|
432 |
+
--lr_scheduler constant_with_warmup {line_break}
|
433 |
+
--max_grad_norm 0.0 {line_break}"""
|
434 |
+
else:
|
435 |
+
# 20G+ VRAM
|
436 |
+
optimizer = f"--optimizer_type adamw8bit {line_break}"
|
437 |
+
|
438 |
+
|
439 |
+
#######################################################
|
440 |
+
model_config = models[base_model]
|
441 |
+
model_file = model_config["file"]
|
442 |
+
repo = model_config["repo"]
|
443 |
+
if base_model == "flux-dev" or base_model == "flux-schnell":
|
444 |
+
model_folder = "models/unet"
|
445 |
+
else:
|
446 |
+
model_folder = f"models/unet/{repo}"
|
447 |
+
model_path = os.path.join(model_folder, model_file)
|
448 |
+
pretrained_model_path = resolve_path(model_path)
|
449 |
+
|
450 |
+
clip_path = resolve_path("models/clip/clip_l.safetensors")
|
451 |
+
t5_path = resolve_path("models/clip/t5xxl_fp16.safetensors")
|
452 |
+
ae_path = resolve_path("models/vae/ae.sft")
|
453 |
+
sh = f"""accelerate launch {line_break}
|
454 |
+
--mixed_precision bf16 {line_break}
|
455 |
+
--num_cpu_threads_per_process 1 {line_break}
|
456 |
+
sd-scripts/flux_train_network.py {line_break}
|
457 |
+
--pretrained_model_name_or_path {pretrained_model_path} {line_break}
|
458 |
+
--clip_l {clip_path} {line_break}
|
459 |
+
--t5xxl {t5_path} {line_break}
|
460 |
+
--ae {ae_path} {line_break}
|
461 |
+
--cache_latents_to_disk {line_break}
|
462 |
+
--save_model_as safetensors {line_break}
|
463 |
+
--sdpa --persistent_data_loader_workers {line_break}
|
464 |
+
--max_data_loader_n_workers {workers} {line_break}
|
465 |
+
--seed {seed} {line_break}
|
466 |
+
--gradient_checkpointing {line_break}
|
467 |
+
--mixed_precision bf16 {line_break}
|
468 |
+
--save_precision bf16 {line_break}
|
469 |
+
--network_module networks.lora_flux {line_break}
|
470 |
+
--network_dim {network_dim} {line_break}
|
471 |
+
{optimizer}{sample}
|
472 |
+
--learning_rate {learning_rate} {line_break}
|
473 |
+
--cache_text_encoder_outputs {line_break}
|
474 |
+
--cache_text_encoder_outputs_to_disk {line_break}
|
475 |
+
--fp8_base {line_break}
|
476 |
+
--highvram {line_break}
|
477 |
+
--max_train_epochs {max_train_epochs} {line_break}
|
478 |
+
--save_every_n_epochs {save_every_n_epochs} {line_break}
|
479 |
+
--dataset_config {resolve_path(f"outputs/{output_name}/dataset.toml")} {line_break}
|
480 |
+
--output_dir {output_dir} {line_break}
|
481 |
+
--output_name {output_name} {line_break}
|
482 |
+
--timestep_sampling {timestep_sampling} {line_break}
|
483 |
+
--discrete_flow_shift 3.1582 {line_break}
|
484 |
+
--model_prediction_type raw {line_break}
|
485 |
+
--guidance_scale {guidance_scale} {line_break}
|
486 |
+
--loss_type l2 {line_break}"""
|
487 |
+
|
488 |
+
|
489 |
+
|
490 |
+
############# Advanced args ########################
|
491 |
+
global advanced_component_ids
|
492 |
+
global original_advanced_component_values
|
493 |
+
|
494 |
+
# check dirty
|
495 |
+
print(f"original_advanced_component_values = {original_advanced_component_values}")
|
496 |
+
advanced_flags = []
|
497 |
+
for i, current_value in enumerate(advanced_components):
|
498 |
+
# print(f"compare {advanced_component_ids[i]}: old={original_advanced_component_values[i]}, new={current_value}")
|
499 |
+
if original_advanced_component_values[i] != current_value:
|
500 |
+
# dirty
|
501 |
+
if current_value == True:
|
502 |
+
# Boolean
|
503 |
+
advanced_flags.append(advanced_component_ids[i])
|
504 |
+
else:
|
505 |
+
# string
|
506 |
+
advanced_flags.append(f"{advanced_component_ids[i]} {current_value}")
|
507 |
+
|
508 |
+
if len(advanced_flags) > 0:
|
509 |
+
advanced_flags_str = f" {line_break}\n ".join(advanced_flags)
|
510 |
+
sh = sh + "\n " + advanced_flags_str
|
511 |
+
|
512 |
+
return sh
|
513 |
+
|
514 |
+
def gen_toml(
|
515 |
+
dataset_folder,
|
516 |
+
resolution,
|
517 |
+
class_tokens,
|
518 |
+
num_repeats
|
519 |
+
):
|
520 |
+
toml = f"""[general]
|
521 |
+
shuffle_caption = false
|
522 |
+
caption_extension = '.txt'
|
523 |
+
keep_tokens = 1
|
524 |
+
|
525 |
+
[[datasets]]
|
526 |
+
resolution = {resolution}
|
527 |
+
batch_size = 1
|
528 |
+
keep_tokens = 1
|
529 |
+
|
530 |
+
[[datasets.subsets]]
|
531 |
+
image_dir = '{resolve_path_without_quotes(dataset_folder)}'
|
532 |
+
class_tokens = '{class_tokens}'
|
533 |
+
num_repeats = {num_repeats}"""
|
534 |
+
return toml
|
535 |
+
|
536 |
+
def update_total_steps(max_train_epochs, num_repeats, images):
|
537 |
+
try:
|
538 |
+
num_images = len(images)
|
539 |
+
total_steps = max_train_epochs * num_images * num_repeats
|
540 |
+
print(f"max_train_epochs={max_train_epochs} num_images={num_images}, num_repeats={num_repeats}, total_steps={total_steps}")
|
541 |
+
return gr.update(value = total_steps)
|
542 |
+
except:
|
543 |
+
print("")
|
544 |
+
|
545 |
+
def set_repo(lora_rows):
|
546 |
+
selected_name = os.path.basename(lora_rows)
|
547 |
+
return gr.update(value=selected_name)
|
548 |
+
|
549 |
+
def get_loras():
|
550 |
+
try:
|
551 |
+
outputs_path = resolve_path_without_quotes(f"outputs")
|
552 |
+
files = os.listdir(outputs_path)
|
553 |
+
folders = [os.path.join(outputs_path, item) for item in files if os.path.isdir(os.path.join(outputs_path, item)) and item != "sample"]
|
554 |
+
folders.sort(key=lambda file: os.path.getctime(file), reverse=True)
|
555 |
+
return folders
|
556 |
+
except Exception as e:
|
557 |
+
return []
|
558 |
+
|
559 |
+
def get_samples(lora_name):
|
560 |
+
output_name = slugify(lora_name)
|
561 |
+
try:
|
562 |
+
samples_path = resolve_path_without_quotes(f"outputs/{output_name}/sample")
|
563 |
+
files = [os.path.join(samples_path, file) for file in os.listdir(samples_path)]
|
564 |
+
files.sort(key=lambda file: os.path.getctime(file), reverse=True)
|
565 |
+
return files
|
566 |
+
except:
|
567 |
+
return []
|
568 |
+
|
569 |
+
def start_training(
|
570 |
+
base_model,
|
571 |
+
lora_name,
|
572 |
+
train_script,
|
573 |
+
train_config,
|
574 |
+
sample_prompts,
|
575 |
+
):
|
576 |
+
# write custom script and toml
|
577 |
+
if not os.path.exists("models"):
|
578 |
+
os.makedirs("models", exist_ok=True)
|
579 |
+
if not os.path.exists("outputs"):
|
580 |
+
os.makedirs("outputs", exist_ok=True)
|
581 |
+
output_name = slugify(lora_name)
|
582 |
+
output_dir = resolve_path_without_quotes(f"outputs/{output_name}")
|
583 |
+
if not os.path.exists(output_dir):
|
584 |
+
os.makedirs(output_dir, exist_ok=True)
|
585 |
+
|
586 |
+
download(base_model)
|
587 |
+
|
588 |
+
file_type = "sh"
|
589 |
+
if sys.platform == "win32":
|
590 |
+
file_type = "bat"
|
591 |
+
|
592 |
+
sh_filename = f"train.{file_type}"
|
593 |
+
sh_filepath = resolve_path_without_quotes(f"outputs/{output_name}/{sh_filename}")
|
594 |
+
with open(sh_filepath, 'w', encoding="utf-8") as file:
|
595 |
+
file.write(train_script)
|
596 |
+
gr.Info(f"Generated train script at {sh_filename}")
|
597 |
+
|
598 |
+
|
599 |
+
dataset_path = resolve_path_without_quotes(f"outputs/{output_name}/dataset.toml")
|
600 |
+
with open(dataset_path, 'w', encoding="utf-8") as file:
|
601 |
+
file.write(train_config)
|
602 |
+
gr.Info(f"Generated dataset.toml")
|
603 |
+
|
604 |
+
sample_prompts_path = resolve_path_without_quotes(f"outputs/{output_name}/sample_prompts.txt")
|
605 |
+
with open(sample_prompts_path, 'w', encoding='utf-8') as file:
|
606 |
+
file.write(sample_prompts)
|
607 |
+
gr.Info(f"Generated sample_prompts.txt")
|
608 |
+
|
609 |
+
# Train
|
610 |
+
if sys.platform == "win32":
|
611 |
+
command = sh_filepath
|
612 |
+
else:
|
613 |
+
command = f"bash \"{sh_filepath}\""
|
614 |
+
|
615 |
+
# Use Popen to run the command and capture output in real-time
|
616 |
+
env = os.environ.copy()
|
617 |
+
env['PYTHONIOENCODING'] = 'utf-8'
|
618 |
+
env['LOG_LEVEL'] = 'DEBUG'
|
619 |
+
runner = LogsViewRunner()
|
620 |
+
cwd = os.path.dirname(os.path.abspath(__file__))
|
621 |
+
gr.Info(f"Started training")
|
622 |
+
yield from runner.run_command([command], cwd=cwd)
|
623 |
+
yield runner.log(f"Runner: {runner}")
|
624 |
+
|
625 |
+
# Generate Readme
|
626 |
+
config = toml.loads(train_config)
|
627 |
+
concept_sentence = config['datasets'][0]['subsets'][0]['class_tokens']
|
628 |
+
print(f"concept_sentence={concept_sentence}")
|
629 |
+
print(f"lora_name {lora_name}, concept_sentence={concept_sentence}, output_name={output_name}")
|
630 |
+
sample_prompts_path = resolve_path_without_quotes(f"outputs/{output_name}/sample_prompts.txt")
|
631 |
+
with open(sample_prompts_path, "r", encoding="utf-8") as f:
|
632 |
+
lines = f.readlines()
|
633 |
+
sample_prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
634 |
+
md = readme(base_model, lora_name, concept_sentence, sample_prompts)
|
635 |
+
readme_path = resolve_path_without_quotes(f"outputs/{output_name}/README.md")
|
636 |
+
with open(readme_path, "w", encoding="utf-8") as f:
|
637 |
+
f.write(md)
|
638 |
+
|
639 |
+
gr.Info(f"Training Complete. Check the outputs folder for the LoRA files.", duration=None)
|
640 |
+
|
641 |
+
|
642 |
+
def update(
|
643 |
+
base_model,
|
644 |
+
lora_name,
|
645 |
+
resolution,
|
646 |
+
seed,
|
647 |
+
workers,
|
648 |
+
class_tokens,
|
649 |
+
learning_rate,
|
650 |
+
network_dim,
|
651 |
+
max_train_epochs,
|
652 |
+
save_every_n_epochs,
|
653 |
+
timestep_sampling,
|
654 |
+
guidance_scale,
|
655 |
+
vram,
|
656 |
+
num_repeats,
|
657 |
+
sample_prompts,
|
658 |
+
sample_every_n_steps,
|
659 |
+
*advanced_components,
|
660 |
+
):
|
661 |
+
output_name = slugify(lora_name)
|
662 |
+
dataset_folder = str(f"datasets/{output_name}")
|
663 |
+
sh = gen_sh(
|
664 |
+
base_model,
|
665 |
+
output_name,
|
666 |
+
resolution,
|
667 |
+
seed,
|
668 |
+
workers,
|
669 |
+
learning_rate,
|
670 |
+
network_dim,
|
671 |
+
max_train_epochs,
|
672 |
+
save_every_n_epochs,
|
673 |
+
timestep_sampling,
|
674 |
+
guidance_scale,
|
675 |
+
vram,
|
676 |
+
sample_prompts,
|
677 |
+
sample_every_n_steps,
|
678 |
+
*advanced_components,
|
679 |
+
)
|
680 |
+
toml = gen_toml(
|
681 |
+
dataset_folder,
|
682 |
+
resolution,
|
683 |
+
class_tokens,
|
684 |
+
num_repeats
|
685 |
+
)
|
686 |
+
return gr.update(value=sh), gr.update(value=toml), dataset_folder
|
687 |
+
|
688 |
+
"""
|
689 |
+
demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, hf_account])
|
690 |
+
"""
|
691 |
+
def loaded():
|
692 |
+
global current_account
|
693 |
+
current_account = account_hf()
|
694 |
+
print(f"current_account={current_account}")
|
695 |
+
if current_account != None:
|
696 |
+
return gr.update(value=current_account["token"]), gr.update(visible=False), gr.update(visible=True), gr.update(value=current_account["account"], visible=True)
|
697 |
+
else:
|
698 |
+
return gr.update(value=""), gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False)
|
699 |
+
|
700 |
+
def update_sample(concept_sentence):
|
701 |
+
return gr.update(value=concept_sentence)
|
702 |
+
|
703 |
+
def refresh_publish_tab():
|
704 |
+
loras = get_loras()
|
705 |
+
return gr.Dropdown(label="Trained LoRAs", choices=loras)
|
706 |
+
|
707 |
+
def init_advanced():
|
708 |
+
# if basic_args
|
709 |
+
basic_args = {
|
710 |
+
'pretrained_model_name_or_path',
|
711 |
+
'clip_l',
|
712 |
+
't5xxl',
|
713 |
+
'ae',
|
714 |
+
'cache_latents_to_disk',
|
715 |
+
'save_model_as',
|
716 |
+
'sdpa',
|
717 |
+
'persistent_data_loader_workers',
|
718 |
+
'max_data_loader_n_workers',
|
719 |
+
'seed',
|
720 |
+
'gradient_checkpointing',
|
721 |
+
'mixed_precision',
|
722 |
+
'save_precision',
|
723 |
+
'network_module',
|
724 |
+
'network_dim',
|
725 |
+
'learning_rate',
|
726 |
+
'cache_text_encoder_outputs',
|
727 |
+
'cache_text_encoder_outputs_to_disk',
|
728 |
+
'fp8_base',
|
729 |
+
'highvram',
|
730 |
+
'max_train_epochs',
|
731 |
+
'save_every_n_epochs',
|
732 |
+
'dataset_config',
|
733 |
+
'output_dir',
|
734 |
+
'output_name',
|
735 |
+
'timestep_sampling',
|
736 |
+
'discrete_flow_shift',
|
737 |
+
'model_prediction_type',
|
738 |
+
'guidance_scale',
|
739 |
+
'loss_type',
|
740 |
+
'optimizer_type',
|
741 |
+
'optimizer_args',
|
742 |
+
'lr_scheduler',
|
743 |
+
'sample_prompts',
|
744 |
+
'sample_every_n_steps',
|
745 |
+
'max_grad_norm',
|
746 |
+
'split_mode',
|
747 |
+
'network_args'
|
748 |
+
}
|
749 |
+
|
750 |
+
# generate a UI config
|
751 |
+
# if not in basic_args, create a simple form
|
752 |
+
parser = train_network.setup_parser()
|
753 |
+
flux_train_utils.add_flux_train_arguments(parser)
|
754 |
+
args_info = {}
|
755 |
+
for action in parser._actions:
|
756 |
+
if action.dest != 'help': # Skip the default help argument
|
757 |
+
# if the dest is included in basic_args
|
758 |
+
args_info[action.dest] = {
|
759 |
+
"action": action.option_strings, # Option strings like '--use_8bit_adam'
|
760 |
+
"type": action.type, # Type of the argument
|
761 |
+
"help": action.help, # Help message
|
762 |
+
"default": action.default, # Default value, if any
|
763 |
+
"required": action.required # Whether the argument is required
|
764 |
+
}
|
765 |
+
temp = []
|
766 |
+
for key in args_info:
|
767 |
+
temp.append({ 'key': key, 'action': args_info[key] })
|
768 |
+
temp.sort(key=lambda x: x['key'])
|
769 |
+
advanced_component_ids = []
|
770 |
+
advanced_components = []
|
771 |
+
for item in temp:
|
772 |
+
key = item['key']
|
773 |
+
action = item['action']
|
774 |
+
if key in basic_args:
|
775 |
+
print("")
|
776 |
+
else:
|
777 |
+
action_type = str(action['type'])
|
778 |
+
component = None
|
779 |
+
with gr.Column(min_width=300):
|
780 |
+
if action_type == "None":
|
781 |
+
# radio
|
782 |
+
component = gr.Checkbox()
|
783 |
+
# elif action_type == "<class 'str'>":
|
784 |
+
# component = gr.Textbox()
|
785 |
+
# elif action_type == "<class 'int'>":
|
786 |
+
# component = gr.Number(precision=0)
|
787 |
+
# elif action_type == "<class 'float'>":
|
788 |
+
# component = gr.Number()
|
789 |
+
# elif "int_or_float" in action_type:
|
790 |
+
# component = gr.Number()
|
791 |
+
else:
|
792 |
+
component = gr.Textbox(value="")
|
793 |
+
if component != None:
|
794 |
+
component.interactive = True
|
795 |
+
component.elem_id = action['action'][0]
|
796 |
+
component.label = component.elem_id
|
797 |
+
component.elem_classes = ["advanced"]
|
798 |
+
if action['help'] != None:
|
799 |
+
component.info = action['help']
|
800 |
+
advanced_components.append(component)
|
801 |
+
advanced_component_ids.append(component.elem_id)
|
802 |
+
return advanced_components, advanced_component_ids
|
803 |
+
|
804 |
+
|
805 |
+
theme = gr.themes.Monochrome(
|
806 |
+
text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"),
|
807 |
+
font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"],
|
808 |
+
)
|
809 |
+
css = """
|
810 |
+
@keyframes rotate {
|
811 |
+
0% {
|
812 |
+
transform: rotate(0deg);
|
813 |
+
}
|
814 |
+
100% {
|
815 |
+
transform: rotate(360deg);
|
816 |
+
}
|
817 |
+
}
|
818 |
+
#advanced_options .advanced:nth-child(even) { background: rgba(0,0,100,0.04) !important; }
|
819 |
+
h1{font-family: georgia; font-style: italic; font-weight: bold; font-size: 30px; letter-spacing: -1px;}
|
820 |
+
h3{margin-top: 0}
|
821 |
+
.tabitem{border: 0px}
|
822 |
+
.group_padding{}
|
823 |
+
nav{position: fixed; top: 0; left: 0; right: 0; z-index: 1000; text-align: center; padding: 10px; box-sizing: border-box; display: flex; align-items: center; backdrop-filter: blur(10px); }
|
824 |
+
nav button { background: none; color: firebrick; font-weight: bold; border: 2px solid firebrick; padding: 5px 10px; border-radius: 5px; font-size: 14px; }
|
825 |
+
nav img { height: 40px; width: 40px; border-radius: 40px; }
|
826 |
+
nav img.rotate { animation: rotate 2s linear infinite; }
|
827 |
+
.flexible { flex-grow: 1; }
|
828 |
+
.tast-details { margin: 10px 0 !important; }
|
829 |
+
.toast-wrap { bottom: var(--size-4) !important; top: auto !important; border: none !important; backdrop-filter: blur(10px); }
|
830 |
+
.toast-title, .toast-text, .toast-icon, .toast-close { color: black !important; font-size: 14px; }
|
831 |
+
.toast-body { border: none !important; }
|
832 |
+
#terminal { box-shadow: none !important; margin-bottom: 25px; background: rgba(0,0,0,0.03); }
|
833 |
+
#terminal .generating { border: none !important; }
|
834 |
+
#terminal label { position: absolute !important; }
|
835 |
+
.tabs { margin-top: 50px; }
|
836 |
+
.hidden { display: none !important; }
|
837 |
+
.codemirror-wrapper .cm-line { font-size: 12px !important; }
|
838 |
+
label { font-weight: bold !important; }
|
839 |
+
#start_training.clicked { background: silver; color: black; }
|
840 |
+
"""
|
841 |
+
|
842 |
+
js = """
|
843 |
+
function() {
|
844 |
+
let autoscroll = document.querySelector("#autoscroll")
|
845 |
+
if (window.iidxx) {
|
846 |
+
window.clearInterval(window.iidxx);
|
847 |
+
}
|
848 |
+
window.iidxx = window.setInterval(function() {
|
849 |
+
let text=document.querySelector(".codemirror-wrapper .cm-line").innerText.trim()
|
850 |
+
let img = document.querySelector("#logo")
|
851 |
+
if (text.length > 0) {
|
852 |
+
autoscroll.classList.remove("hidden")
|
853 |
+
if (autoscroll.classList.contains("on")) {
|
854 |
+
autoscroll.textContent = "Autoscroll ON"
|
855 |
+
window.scrollTo(0, document.body.scrollHeight, { behavior: "smooth" });
|
856 |
+
img.classList.add("rotate")
|
857 |
+
} else {
|
858 |
+
autoscroll.textContent = "Autoscroll OFF"
|
859 |
+
img.classList.remove("rotate")
|
860 |
+
}
|
861 |
+
}
|
862 |
+
}, 500);
|
863 |
+
console.log("autoscroll", autoscroll)
|
864 |
+
autoscroll.addEventListener("click", (e) => {
|
865 |
+
autoscroll.classList.toggle("on")
|
866 |
+
})
|
867 |
+
function debounce(fn, delay) {
|
868 |
+
let timeoutId;
|
869 |
+
return function(...args) {
|
870 |
+
clearTimeout(timeoutId);
|
871 |
+
timeoutId = setTimeout(() => fn(...args), delay);
|
872 |
+
};
|
873 |
+
}
|
874 |
+
|
875 |
+
function handleClick() {
|
876 |
+
console.log("refresh")
|
877 |
+
document.querySelector("#refresh").click();
|
878 |
+
}
|
879 |
+
const debouncedClick = debounce(handleClick, 1000);
|
880 |
+
document.addEventListener("input", debouncedClick);
|
881 |
+
|
882 |
+
document.querySelector("#start_training").addEventListener("click", (e) => {
|
883 |
+
e.target.classList.add("clicked")
|
884 |
+
e.target.innerHTML = "Training..."
|
885 |
+
})
|
886 |
+
|
887 |
+
}
|
888 |
+
"""
|
889 |
+
|
890 |
+
current_account = account_hf()
|
891 |
+
print(f"current_account={current_account}")
|
892 |
+
|
893 |
+
with gr.Blocks(elem_id="app", theme=theme, css=css, fill_width=True) as demo:
|
894 |
+
with gr.Tabs() as tabs:
|
895 |
+
with gr.TabItem("Gym"):
|
896 |
+
output_components = []
|
897 |
+
with gr.Row():
|
898 |
+
gr.HTML("""<nav>
|
899 |
+
<img id='logo' src='/file=icon.png' width='80' height='80'>
|
900 |
+
<div class='flexible'></div>
|
901 |
+
<button id='autoscroll' class='on hidden'></button>
|
902 |
+
</nav>
|
903 |
+
""")
|
904 |
+
with gr.Row(elem_id='container'):
|
905 |
+
with gr.Column():
|
906 |
+
gr.Markdown(
|
907 |
+
"""# Step 1. LoRA Info
|
908 |
+
<p style="margin-top:0">Configure your LoRA train settings.</p>
|
909 |
+
""", elem_classes="group_padding")
|
910 |
+
lora_name = gr.Textbox(
|
911 |
+
label="The name of your LoRA",
|
912 |
+
info="This has to be a unique name",
|
913 |
+
placeholder="e.g.: Persian Miniature Painting style, Cat Toy",
|
914 |
+
)
|
915 |
+
concept_sentence = gr.Textbox(
|
916 |
+
elem_id="--concept_sentence",
|
917 |
+
label="Trigger word/sentence",
|
918 |
+
info="Trigger word or sentence to be used",
|
919 |
+
placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
|
920 |
+
interactive=True,
|
921 |
+
)
|
922 |
+
model_names = list(models.keys())
|
923 |
+
print(f"model_names={model_names}")
|
924 |
+
base_model = gr.Dropdown(label="Base model (edit the models.yaml file to add more to this list)", choices=model_names, value=model_names[0])
|
925 |
+
vram = gr.Radio(["20G", "16G", "12G" ], value="20G", label="VRAM", interactive=True)
|
926 |
+
num_repeats = gr.Number(value=10, precision=0, label="Repeat trains per image", interactive=True)
|
927 |
+
max_train_epochs = gr.Number(label="Max Train Epochs", value=16, interactive=True)
|
928 |
+
total_steps = gr.Number(0, interactive=False, label="Expected training steps")
|
929 |
+
sample_prompts = gr.Textbox("", lines=5, label="Sample Image Prompts (Separate with new lines)", interactive=True)
|
930 |
+
sample_every_n_steps = gr.Number(0, precision=0, label="Sample Image Every N Steps", interactive=True)
|
931 |
+
resolution = gr.Number(value=512, precision=0, label="Resize dataset images")
|
932 |
+
with gr.Column():
|
933 |
+
gr.Markdown(
|
934 |
+
"""# Step 2. Dataset
|
935 |
+
<p style="margin-top:0">Make sure the captions include the trigger word.</p>
|
936 |
+
""", elem_classes="group_padding")
|
937 |
+
with gr.Group():
|
938 |
+
images = gr.File(
|
939 |
+
file_types=["image", ".txt"],
|
940 |
+
label="Upload your images",
|
941 |
+
#info="If you want, you can also manually upload caption files that match the image names (example: img0.png => img0.txt)",
|
942 |
+
file_count="multiple",
|
943 |
+
interactive=True,
|
944 |
+
visible=True,
|
945 |
+
scale=1,
|
946 |
+
)
|
947 |
+
with gr.Group(visible=False) as captioning_area:
|
948 |
+
do_captioning = gr.Button("Add AI captions with Florence-2")
|
949 |
+
output_components.append(captioning_area)
|
950 |
+
#output_components = [captioning_area]
|
951 |
+
caption_list = []
|
952 |
+
for i in range(1, MAX_IMAGES + 1):
|
953 |
+
locals()[f"captioning_row_{i}"] = gr.Row(visible=False)
|
954 |
+
with locals()[f"captioning_row_{i}"]:
|
955 |
+
locals()[f"image_{i}"] = gr.Image(
|
956 |
+
type="filepath",
|
957 |
+
width=111,
|
958 |
+
height=111,
|
959 |
+
min_width=111,
|
960 |
+
interactive=False,
|
961 |
+
scale=2,
|
962 |
+
show_label=False,
|
963 |
+
show_share_button=False,
|
964 |
+
show_download_button=False,
|
965 |
+
)
|
966 |
+
locals()[f"caption_{i}"] = gr.Textbox(
|
967 |
+
label=f"Caption {i}", scale=15, interactive=True
|
968 |
+
)
|
969 |
+
|
970 |
+
output_components.append(locals()[f"captioning_row_{i}"])
|
971 |
+
output_components.append(locals()[f"image_{i}"])
|
972 |
+
output_components.append(locals()[f"caption_{i}"])
|
973 |
+
caption_list.append(locals()[f"caption_{i}"])
|
974 |
+
with gr.Column():
|
975 |
+
gr.Markdown(
|
976 |
+
"""# Step 3. Train
|
977 |
+
<p style="margin-top:0">Press start to start training.</p>
|
978 |
+
""", elem_classes="group_padding")
|
979 |
+
refresh = gr.Button("Refresh", elem_id="refresh", visible=False)
|
980 |
+
start = gr.Button("Start training", visible=False, elem_id="start_training")
|
981 |
+
output_components.append(start)
|
982 |
+
train_script = gr.Textbox(label="Train script", max_lines=100, interactive=True)
|
983 |
+
train_config = gr.Textbox(label="Train config", max_lines=100, interactive=True)
|
984 |
+
with gr.Accordion("Advanced options", elem_id='advanced_options', open=False):
|
985 |
+
with gr.Row():
|
986 |
+
with gr.Column(min_width=300):
|
987 |
+
seed = gr.Number(label="--seed", info="Seed", value=42, interactive=True)
|
988 |
+
with gr.Column(min_width=300):
|
989 |
+
workers = gr.Number(label="--max_data_loader_n_workers", info="Number of Workers", value=2, interactive=True)
|
990 |
+
with gr.Column(min_width=300):
|
991 |
+
learning_rate = gr.Textbox(label="--learning_rate", info="Learning Rate", value="8e-4", interactive=True)
|
992 |
+
with gr.Column(min_width=300):
|
993 |
+
save_every_n_epochs = gr.Number(label="--save_every_n_epochs", info="Save every N epochs", value=4, interactive=True)
|
994 |
+
with gr.Column(min_width=300):
|
995 |
+
guidance_scale = gr.Number(label="--guidance_scale", info="Guidance Scale", value=1.0, interactive=True)
|
996 |
+
with gr.Column(min_width=300):
|
997 |
+
timestep_sampling = gr.Textbox(label="--timestep_sampling", info="Timestep Sampling", value="shift", interactive=True)
|
998 |
+
with gr.Column(min_width=300):
|
999 |
+
network_dim = gr.Number(label="--network_dim", info="LoRA Rank", value=4, minimum=4, maximum=128, step=4, interactive=True)
|
1000 |
+
advanced_components, advanced_component_ids = init_advanced()
|
1001 |
+
with gr.Row():
|
1002 |
+
terminal = LogsView(label="Train log", elem_id="terminal")
|
1003 |
+
with gr.Row():
|
1004 |
+
gallery = gr.Gallery(get_samples, inputs=[lora_name], label="Samples", every=10, columns=6)
|
1005 |
+
|
1006 |
+
with gr.TabItem("Publish") as publish_tab:
|
1007 |
+
hf_token = gr.Textbox(label="Huggingface Token")
|
1008 |
+
hf_login = gr.Button("Login")
|
1009 |
+
hf_logout = gr.Button("Logout")
|
1010 |
+
with gr.Row() as row:
|
1011 |
+
gr.Markdown("**LoRA**")
|
1012 |
+
gr.Markdown("**Upload**")
|
1013 |
+
loras = get_loras()
|
1014 |
+
with gr.Row():
|
1015 |
+
lora_rows = refresh_publish_tab()
|
1016 |
+
with gr.Column():
|
1017 |
+
with gr.Row():
|
1018 |
+
repo_owner = gr.Textbox(label="Account", interactive=False)
|
1019 |
+
repo_name = gr.Textbox(label="Repository Name")
|
1020 |
+
repo_visibility = gr.Textbox(label="Repository Visibility ('public' or 'private')", value="public")
|
1021 |
+
upload_button = gr.Button("Upload to HuggingFace")
|
1022 |
+
upload_button.click(
|
1023 |
+
fn=upload_hf,
|
1024 |
+
inputs=[
|
1025 |
+
base_model,
|
1026 |
+
lora_rows,
|
1027 |
+
repo_owner,
|
1028 |
+
repo_name,
|
1029 |
+
repo_visibility,
|
1030 |
+
hf_token,
|
1031 |
+
]
|
1032 |
+
)
|
1033 |
+
hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner])
|
1034 |
+
hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner])
|
1035 |
+
|
1036 |
+
|
1037 |
+
publish_tab.select(refresh_publish_tab, outputs=lora_rows)
|
1038 |
+
lora_rows.select(fn=set_repo, inputs=[lora_rows], outputs=[repo_name])
|
1039 |
+
|
1040 |
+
dataset_folder = gr.State()
|
1041 |
+
|
1042 |
+
listeners = [
|
1043 |
+
base_model,
|
1044 |
+
lora_name,
|
1045 |
+
resolution,
|
1046 |
+
seed,
|
1047 |
+
workers,
|
1048 |
+
concept_sentence,
|
1049 |
+
learning_rate,
|
1050 |
+
network_dim,
|
1051 |
+
max_train_epochs,
|
1052 |
+
save_every_n_epochs,
|
1053 |
+
timestep_sampling,
|
1054 |
+
guidance_scale,
|
1055 |
+
vram,
|
1056 |
+
num_repeats,
|
1057 |
+
sample_prompts,
|
1058 |
+
sample_every_n_steps,
|
1059 |
+
*advanced_components
|
1060 |
+
]
|
1061 |
+
advanced_component_ids = [x.elem_id for x in advanced_components]
|
1062 |
+
original_advanced_component_values = [comp.value for comp in advanced_components]
|
1063 |
+
images.upload(
|
1064 |
+
load_captioning,
|
1065 |
+
inputs=[images, concept_sentence],
|
1066 |
+
outputs=output_components
|
1067 |
+
)
|
1068 |
+
images.delete(
|
1069 |
+
load_captioning,
|
1070 |
+
inputs=[images, concept_sentence],
|
1071 |
+
outputs=output_components
|
1072 |
+
)
|
1073 |
+
images.clear(
|
1074 |
+
hide_captioning,
|
1075 |
+
outputs=[captioning_area, start]
|
1076 |
+
)
|
1077 |
+
max_train_epochs.change(
|
1078 |
+
fn=update_total_steps,
|
1079 |
+
inputs=[max_train_epochs, num_repeats, images],
|
1080 |
+
outputs=[total_steps]
|
1081 |
+
)
|
1082 |
+
num_repeats.change(
|
1083 |
+
fn=update_total_steps,
|
1084 |
+
inputs=[max_train_epochs, num_repeats, images],
|
1085 |
+
outputs=[total_steps]
|
1086 |
+
)
|
1087 |
+
images.upload(
|
1088 |
+
fn=update_total_steps,
|
1089 |
+
inputs=[max_train_epochs, num_repeats, images],
|
1090 |
+
outputs=[total_steps]
|
1091 |
+
)
|
1092 |
+
images.delete(
|
1093 |
+
fn=update_total_steps,
|
1094 |
+
inputs=[max_train_epochs, num_repeats, images],
|
1095 |
+
outputs=[total_steps]
|
1096 |
+
)
|
1097 |
+
images.clear(
|
1098 |
+
fn=update_total_steps,
|
1099 |
+
inputs=[max_train_epochs, num_repeats, images],
|
1100 |
+
outputs=[total_steps]
|
1101 |
+
)
|
1102 |
+
concept_sentence.change(fn=update_sample, inputs=[concept_sentence], outputs=sample_prompts)
|
1103 |
+
start.click(fn=create_dataset, inputs=[dataset_folder, resolution, images] + caption_list, outputs=dataset_folder).then(
|
1104 |
+
fn=start_training,
|
1105 |
+
inputs=[
|
1106 |
+
base_model,
|
1107 |
+
lora_name,
|
1108 |
+
train_script,
|
1109 |
+
train_config,
|
1110 |
+
sample_prompts,
|
1111 |
+
],
|
1112 |
+
outputs=terminal,
|
1113 |
+
)
|
1114 |
+
do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list)
|
1115 |
+
demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, repo_owner])
|
1116 |
+
refresh.click(update, inputs=listeners, outputs=[train_script, train_config, dataset_folder])
|
1117 |
+
if __name__ == "__main__":
|
1118 |
+
cwd = os.path.dirname(os.path.abspath(__file__))
|
1119 |
+
demo.launch(debug=True, show_error=True, allowed_paths=[cwd])
|
datasets/1
ADDED
File without changes
|
docker-compose.yml
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
services:
|
2 |
+
|
3 |
+
fluxgym:
|
4 |
+
build:
|
5 |
+
context: .
|
6 |
+
# change the dockerfile to Dockerfile.cuda12.4 if you are running CUDA 12.4 drivers otherwise leave as is
|
7 |
+
dockerfile: Dockerfile
|
8 |
+
image: fluxgym
|
9 |
+
container_name: fluxgym
|
10 |
+
ports:
|
11 |
+
- 7860:7860
|
12 |
+
environment:
|
13 |
+
- PUID=${PUID:-1000}
|
14 |
+
- PGID=${PGID:-1000}
|
15 |
+
volumes:
|
16 |
+
- /etc/localtime:/etc/localtime:ro
|
17 |
+
- /etc/timezone:/etc/timezone:ro
|
18 |
+
- ./:/app/fluxgym
|
19 |
+
stop_signal: SIGKILL
|
20 |
+
tty: true
|
21 |
+
deploy:
|
22 |
+
resources:
|
23 |
+
reservations:
|
24 |
+
devices:
|
25 |
+
- driver: nvidia
|
26 |
+
count: all
|
27 |
+
capabilities: [gpu]
|
28 |
+
restart: unless-stopped
|
fine_tune.py
ADDED
@@ -0,0 +1,560 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# training with captions
|
2 |
+
# XXX dropped option: hypernetwork training
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
from multiprocessing import Value
|
8 |
+
import toml
|
9 |
+
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from library import deepspeed_utils, strategy_base
|
14 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
15 |
+
|
16 |
+
init_ipex()
|
17 |
+
|
18 |
+
from accelerate.utils import set_seed
|
19 |
+
from diffusers import DDPMScheduler
|
20 |
+
|
21 |
+
from .utils import setup_logging, add_logging_arguments
|
22 |
+
|
23 |
+
setup_logging()
|
24 |
+
import logging
|
25 |
+
|
26 |
+
logger = logging.getLogger(__name__)
|
27 |
+
|
28 |
+
import library.train_util as train_util
|
29 |
+
import library.config_util as config_util
|
30 |
+
from library.config_util import (
|
31 |
+
ConfigSanitizer,
|
32 |
+
BlueprintGenerator,
|
33 |
+
)
|
34 |
+
import library.custom_train_functions as custom_train_functions
|
35 |
+
from library.custom_train_functions import (
|
36 |
+
apply_snr_weight,
|
37 |
+
get_weighted_text_embeddings,
|
38 |
+
prepare_scheduler_for_custom_training,
|
39 |
+
scale_v_prediction_loss_like_noise_prediction,
|
40 |
+
apply_debiased_estimation,
|
41 |
+
)
|
42 |
+
import library.strategy_sd as strategy_sd
|
43 |
+
|
44 |
+
|
45 |
+
def train(args):
|
46 |
+
train_util.verify_training_args(args)
|
47 |
+
train_util.prepare_dataset_args(args, True)
|
48 |
+
deepspeed_utils.prepare_deepspeed_args(args)
|
49 |
+
setup_logging(args, reset=True)
|
50 |
+
|
51 |
+
cache_latents = args.cache_latents
|
52 |
+
|
53 |
+
if args.seed is not None:
|
54 |
+
set_seed(args.seed) # 乱数系列を初期化する
|
55 |
+
|
56 |
+
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
|
57 |
+
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
58 |
+
|
59 |
+
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
60 |
+
if cache_latents:
|
61 |
+
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
62 |
+
False, args.cache_latents_to_disk, args.vae_batch_size, False
|
63 |
+
)
|
64 |
+
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
65 |
+
|
66 |
+
# データセットを準備する
|
67 |
+
if args.dataset_class is None:
|
68 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
|
69 |
+
if args.dataset_config is not None:
|
70 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
71 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
72 |
+
ignored = ["train_data_dir", "in_json"]
|
73 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
74 |
+
logger.warning(
|
75 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
76 |
+
", ".join(ignored)
|
77 |
+
)
|
78 |
+
)
|
79 |
+
else:
|
80 |
+
user_config = {
|
81 |
+
"datasets": [
|
82 |
+
{
|
83 |
+
"subsets": [
|
84 |
+
{
|
85 |
+
"image_dir": args.train_data_dir,
|
86 |
+
"metadata_file": args.in_json,
|
87 |
+
}
|
88 |
+
]
|
89 |
+
}
|
90 |
+
]
|
91 |
+
}
|
92 |
+
|
93 |
+
blueprint = blueprint_generator.generate(user_config, args)
|
94 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
95 |
+
else:
|
96 |
+
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
97 |
+
|
98 |
+
current_epoch = Value("i", 0)
|
99 |
+
current_step = Value("i", 0)
|
100 |
+
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
101 |
+
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
102 |
+
|
103 |
+
if args.debug_dataset:
|
104 |
+
train_util.debug_dataset(train_dataset_group)
|
105 |
+
return
|
106 |
+
if len(train_dataset_group) == 0:
|
107 |
+
logger.error(
|
108 |
+
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
109 |
+
)
|
110 |
+
return
|
111 |
+
|
112 |
+
if cache_latents:
|
113 |
+
assert (
|
114 |
+
train_dataset_group.is_latent_cacheable()
|
115 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
116 |
+
|
117 |
+
# acceleratorを準備する
|
118 |
+
logger.info("prepare accelerator")
|
119 |
+
accelerator = train_util.prepare_accelerator(args)
|
120 |
+
|
121 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
122 |
+
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
123 |
+
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
124 |
+
|
125 |
+
# モデルを読み込む
|
126 |
+
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
|
127 |
+
|
128 |
+
# verify load/save model formats
|
129 |
+
if load_stable_diffusion_format:
|
130 |
+
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
|
131 |
+
src_diffusers_model_path = None
|
132 |
+
else:
|
133 |
+
src_stable_diffusion_ckpt = None
|
134 |
+
src_diffusers_model_path = args.pretrained_model_name_or_path
|
135 |
+
|
136 |
+
if args.save_model_as is None:
|
137 |
+
save_stable_diffusion_format = load_stable_diffusion_format
|
138 |
+
use_safetensors = args.use_safetensors
|
139 |
+
else:
|
140 |
+
save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
|
141 |
+
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
142 |
+
|
143 |
+
# Diffusers版のxformers使用フラグを設定する関数
|
144 |
+
def set_diffusers_xformers_flag(model, valid):
|
145 |
+
# model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう
|
146 |
+
# pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`)
|
147 |
+
# U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか
|
148 |
+
# 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^)
|
149 |
+
|
150 |
+
# Recursively walk through all the children.
|
151 |
+
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
152 |
+
# gets the message
|
153 |
+
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
154 |
+
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
155 |
+
module.set_use_memory_efficient_attention_xformers(valid)
|
156 |
+
|
157 |
+
for child in module.children():
|
158 |
+
fn_recursive_set_mem_eff(child)
|
159 |
+
|
160 |
+
fn_recursive_set_mem_eff(model)
|
161 |
+
|
162 |
+
# モデルに xformers とか memory efficient attention を組み込む
|
163 |
+
if args.diffusers_xformers:
|
164 |
+
accelerator.print("Use xformers by Diffusers")
|
165 |
+
set_diffusers_xformers_flag(unet, True)
|
166 |
+
else:
|
167 |
+
# Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある
|
168 |
+
accelerator.print("Disable Diffusers' xformers")
|
169 |
+
set_diffusers_xformers_flag(unet, False)
|
170 |
+
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
171 |
+
|
172 |
+
# 学習を準備する
|
173 |
+
if cache_latents:
|
174 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
175 |
+
vae.requires_grad_(False)
|
176 |
+
vae.eval()
|
177 |
+
|
178 |
+
train_dataset_group.new_cache_latents(vae, accelerator)
|
179 |
+
|
180 |
+
vae.to("cpu")
|
181 |
+
clean_memory_on_device(accelerator.device)
|
182 |
+
|
183 |
+
accelerator.wait_for_everyone()
|
184 |
+
|
185 |
+
# 学習を準備する:モデルを適切な状態にする
|
186 |
+
training_models = []
|
187 |
+
if args.gradient_checkpointing:
|
188 |
+
unet.enable_gradient_checkpointing()
|
189 |
+
training_models.append(unet)
|
190 |
+
|
191 |
+
if args.train_text_encoder:
|
192 |
+
accelerator.print("enable text encoder training")
|
193 |
+
if args.gradient_checkpointing:
|
194 |
+
text_encoder.gradient_checkpointing_enable()
|
195 |
+
training_models.append(text_encoder)
|
196 |
+
else:
|
197 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
198 |
+
text_encoder.requires_grad_(False) # text encoderは学習しない
|
199 |
+
if args.gradient_checkpointing:
|
200 |
+
text_encoder.gradient_checkpointing_enable()
|
201 |
+
text_encoder.train() # required for gradient_checkpointing
|
202 |
+
else:
|
203 |
+
text_encoder.eval()
|
204 |
+
|
205 |
+
text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
|
206 |
+
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
207 |
+
|
208 |
+
if not cache_latents:
|
209 |
+
vae.requires_grad_(False)
|
210 |
+
vae.eval()
|
211 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
212 |
+
|
213 |
+
for m in training_models:
|
214 |
+
m.requires_grad_(True)
|
215 |
+
|
216 |
+
trainable_params = []
|
217 |
+
if args.learning_rate_te is None or not args.train_text_encoder:
|
218 |
+
for m in training_models:
|
219 |
+
trainable_params.extend(m.parameters())
|
220 |
+
else:
|
221 |
+
trainable_params = [
|
222 |
+
{"params": list(unet.parameters()), "lr": args.learning_rate},
|
223 |
+
{"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
|
224 |
+
]
|
225 |
+
|
226 |
+
# 学習に必要なクラスを準備する
|
227 |
+
accelerator.print("prepare optimizer, data loader etc.")
|
228 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
|
229 |
+
|
230 |
+
# prepare dataloader
|
231 |
+
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
232 |
+
# some strategies can be None
|
233 |
+
train_dataset_group.set_current_strategies()
|
234 |
+
|
235 |
+
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
236 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
237 |
+
train_dataloader = torch.utils.data.DataLoader(
|
238 |
+
train_dataset_group,
|
239 |
+
batch_size=1,
|
240 |
+
shuffle=True,
|
241 |
+
collate_fn=collator,
|
242 |
+
num_workers=n_workers,
|
243 |
+
persistent_workers=args.persistent_data_loader_workers,
|
244 |
+
)
|
245 |
+
|
246 |
+
# 学習ステップ数を計算する
|
247 |
+
if args.max_train_epochs is not None:
|
248 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
249 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
250 |
+
)
|
251 |
+
accelerator.print(
|
252 |
+
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
253 |
+
)
|
254 |
+
|
255 |
+
# データセット側にも学習ステップを送信
|
256 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
257 |
+
|
258 |
+
# lr schedulerを用意する
|
259 |
+
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
260 |
+
|
261 |
+
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
262 |
+
if args.full_fp16:
|
263 |
+
assert (
|
264 |
+
args.mixed_precision == "fp16"
|
265 |
+
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
266 |
+
accelerator.print("enable full fp16 training.")
|
267 |
+
unet.to(weight_dtype)
|
268 |
+
text_encoder.to(weight_dtype)
|
269 |
+
|
270 |
+
if args.deepspeed:
|
271 |
+
if args.train_text_encoder:
|
272 |
+
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
|
273 |
+
else:
|
274 |
+
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
|
275 |
+
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
276 |
+
ds_model, optimizer, train_dataloader, lr_scheduler
|
277 |
+
)
|
278 |
+
training_models = [ds_model]
|
279 |
+
else:
|
280 |
+
# acceleratorがなんかよろしくやってくれるらしい
|
281 |
+
if args.train_text_encoder:
|
282 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
283 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
284 |
+
)
|
285 |
+
else:
|
286 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
287 |
+
|
288 |
+
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
289 |
+
if args.full_fp16:
|
290 |
+
train_util.patch_accelerator_for_fp16_training(accelerator)
|
291 |
+
|
292 |
+
# resumeする
|
293 |
+
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
294 |
+
|
295 |
+
# epoch数を計算する
|
296 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
297 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
298 |
+
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
299 |
+
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
300 |
+
|
301 |
+
# 学習する
|
302 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
303 |
+
accelerator.print("running training / 学習開始")
|
304 |
+
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
305 |
+
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
306 |
+
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
307 |
+
accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
308 |
+
accelerator.print(
|
309 |
+
f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
310 |
+
)
|
311 |
+
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
312 |
+
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
313 |
+
|
314 |
+
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
315 |
+
global_step = 0
|
316 |
+
|
317 |
+
noise_scheduler = DDPMScheduler(
|
318 |
+
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
319 |
+
)
|
320 |
+
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
321 |
+
if args.zero_terminal_snr:
|
322 |
+
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
323 |
+
|
324 |
+
if accelerator.is_main_process:
|
325 |
+
init_kwargs = {}
|
326 |
+
if args.wandb_run_name:
|
327 |
+
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
328 |
+
if args.log_tracker_config is not None:
|
329 |
+
init_kwargs = toml.load(args.log_tracker_config)
|
330 |
+
accelerator.init_trackers(
|
331 |
+
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
|
332 |
+
config=train_util.get_sanitized_config_or_none(args),
|
333 |
+
init_kwargs=init_kwargs,
|
334 |
+
)
|
335 |
+
|
336 |
+
# For --sample_at_first
|
337 |
+
train_util.sample_images(
|
338 |
+
accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
|
339 |
+
)
|
340 |
+
|
341 |
+
loss_recorder = train_util.LossRecorder()
|
342 |
+
for epoch in range(num_train_epochs):
|
343 |
+
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
344 |
+
current_epoch.value = epoch + 1
|
345 |
+
|
346 |
+
for m in training_models:
|
347 |
+
m.train()
|
348 |
+
|
349 |
+
for step, batch in enumerate(train_dataloader):
|
350 |
+
current_step.value = global_step
|
351 |
+
with accelerator.accumulate(*training_models):
|
352 |
+
with torch.no_grad():
|
353 |
+
if "latents" in batch and batch["latents"] is not None:
|
354 |
+
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
355 |
+
else:
|
356 |
+
# latentに変換
|
357 |
+
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype)
|
358 |
+
latents = latents * 0.18215
|
359 |
+
b_size = latents.shape[0]
|
360 |
+
|
361 |
+
with torch.set_grad_enabled(args.train_text_encoder):
|
362 |
+
# Get the text embedding for conditioning
|
363 |
+
if args.weighted_captions:
|
364 |
+
# TODO move to strategy_sd.py
|
365 |
+
encoder_hidden_states = get_weighted_text_embeddings(
|
366 |
+
tokenize_strategy.tokenizer,
|
367 |
+
text_encoder,
|
368 |
+
batch["captions"],
|
369 |
+
accelerator.device,
|
370 |
+
args.max_token_length // 75 if args.max_token_length else 1,
|
371 |
+
clip_skip=args.clip_skip,
|
372 |
+
)
|
373 |
+
else:
|
374 |
+
input_ids = batch["input_ids_list"][0].to(accelerator.device)
|
375 |
+
encoder_hidden_states = text_encoding_strategy.encode_tokens(
|
376 |
+
tokenize_strategy, [text_encoder], [input_ids]
|
377 |
+
)[0]
|
378 |
+
if args.full_fp16:
|
379 |
+
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
|
380 |
+
|
381 |
+
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
382 |
+
# with noise offset and/or multires noise if specified
|
383 |
+
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
384 |
+
args, noise_scheduler, latents
|
385 |
+
)
|
386 |
+
|
387 |
+
# Predict the noise residual
|
388 |
+
with accelerator.autocast():
|
389 |
+
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
390 |
+
|
391 |
+
if args.v_parameterization:
|
392 |
+
# v-parameterization training
|
393 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
394 |
+
else:
|
395 |
+
target = noise
|
396 |
+
|
397 |
+
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
|
398 |
+
# do not mean over batch dimension for snr weight or scale v-pred loss
|
399 |
+
loss = train_util.conditional_loss(
|
400 |
+
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
401 |
+
)
|
402 |
+
loss = loss.mean([1, 2, 3])
|
403 |
+
|
404 |
+
if args.min_snr_gamma:
|
405 |
+
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
406 |
+
if args.scale_v_pred_loss_like_noise_pred:
|
407 |
+
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
408 |
+
if args.debiased_estimation_loss:
|
409 |
+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
410 |
+
|
411 |
+
loss = loss.mean() # mean over batch dimension
|
412 |
+
else:
|
413 |
+
loss = train_util.conditional_loss(
|
414 |
+
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
|
415 |
+
)
|
416 |
+
|
417 |
+
accelerator.backward(loss)
|
418 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
419 |
+
params_to_clip = []
|
420 |
+
for m in training_models:
|
421 |
+
params_to_clip.extend(m.parameters())
|
422 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
423 |
+
|
424 |
+
optimizer.step()
|
425 |
+
lr_scheduler.step()
|
426 |
+
optimizer.zero_grad(set_to_none=True)
|
427 |
+
|
428 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
429 |
+
if accelerator.sync_gradients:
|
430 |
+
progress_bar.update(1)
|
431 |
+
global_step += 1
|
432 |
+
|
433 |
+
train_util.sample_images(
|
434 |
+
accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
|
435 |
+
)
|
436 |
+
|
437 |
+
# 指定ステップごとにモデルを保存
|
438 |
+
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
439 |
+
accelerator.wait_for_everyone()
|
440 |
+
if accelerator.is_main_process:
|
441 |
+
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
442 |
+
train_util.save_sd_model_on_epoch_end_or_stepwise(
|
443 |
+
args,
|
444 |
+
False,
|
445 |
+
accelerator,
|
446 |
+
src_path,
|
447 |
+
save_stable_diffusion_format,
|
448 |
+
use_safetensors,
|
449 |
+
save_dtype,
|
450 |
+
epoch,
|
451 |
+
num_train_epochs,
|
452 |
+
global_step,
|
453 |
+
accelerator.unwrap_model(text_encoder),
|
454 |
+
accelerator.unwrap_model(unet),
|
455 |
+
vae,
|
456 |
+
)
|
457 |
+
|
458 |
+
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
459 |
+
if args.logging_dir is not None:
|
460 |
+
logs = {"loss": current_loss}
|
461 |
+
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
|
462 |
+
accelerator.log(logs, step=global_step)
|
463 |
+
|
464 |
+
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
465 |
+
avr_loss: float = loss_recorder.moving_average
|
466 |
+
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
467 |
+
progress_bar.set_postfix(**logs)
|
468 |
+
|
469 |
+
if global_step >= args.max_train_steps:
|
470 |
+
break
|
471 |
+
|
472 |
+
if args.logging_dir is not None:
|
473 |
+
logs = {"loss/epoch": loss_recorder.moving_average}
|
474 |
+
accelerator.log(logs, step=epoch + 1)
|
475 |
+
|
476 |
+
accelerator.wait_for_everyone()
|
477 |
+
|
478 |
+
if args.save_every_n_epochs is not None:
|
479 |
+
if accelerator.is_main_process:
|
480 |
+
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
481 |
+
train_util.save_sd_model_on_epoch_end_or_stepwise(
|
482 |
+
args,
|
483 |
+
True,
|
484 |
+
accelerator,
|
485 |
+
src_path,
|
486 |
+
save_stable_diffusion_format,
|
487 |
+
use_safetensors,
|
488 |
+
save_dtype,
|
489 |
+
epoch,
|
490 |
+
num_train_epochs,
|
491 |
+
global_step,
|
492 |
+
accelerator.unwrap_model(text_encoder),
|
493 |
+
accelerator.unwrap_model(unet),
|
494 |
+
vae,
|
495 |
+
)
|
496 |
+
|
497 |
+
train_util.sample_images(
|
498 |
+
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
|
499 |
+
)
|
500 |
+
|
501 |
+
is_main_process = accelerator.is_main_process
|
502 |
+
if is_main_process:
|
503 |
+
unet = accelerator.unwrap_model(unet)
|
504 |
+
text_encoder = accelerator.unwrap_model(text_encoder)
|
505 |
+
|
506 |
+
accelerator.end_training()
|
507 |
+
|
508 |
+
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
509 |
+
train_util.save_state_on_train_end(args, accelerator)
|
510 |
+
|
511 |
+
del accelerator # この後メモリを使うのでこれは消す
|
512 |
+
|
513 |
+
if is_main_process:
|
514 |
+
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
515 |
+
train_util.save_sd_model_on_train_end(
|
516 |
+
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
|
517 |
+
)
|
518 |
+
logger.info("model saved.")
|
519 |
+
|
520 |
+
|
521 |
+
def setup_parser() -> argparse.ArgumentParser:
|
522 |
+
parser = argparse.ArgumentParser()
|
523 |
+
|
524 |
+
add_logging_arguments(parser)
|
525 |
+
train_util.add_sd_models_arguments(parser)
|
526 |
+
train_util.add_dataset_arguments(parser, False, True, True)
|
527 |
+
train_util.add_training_arguments(parser, False)
|
528 |
+
deepspeed_utils.add_deepspeed_arguments(parser)
|
529 |
+
train_util.add_sd_saving_arguments(parser)
|
530 |
+
train_util.add_optimizer_arguments(parser)
|
531 |
+
config_util.add_config_arguments(parser)
|
532 |
+
custom_train_functions.add_custom_train_arguments(parser)
|
533 |
+
|
534 |
+
parser.add_argument(
|
535 |
+
"--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
|
536 |
+
)
|
537 |
+
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
538 |
+
parser.add_argument(
|
539 |
+
"--learning_rate_te",
|
540 |
+
type=float,
|
541 |
+
default=None,
|
542 |
+
help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
|
543 |
+
)
|
544 |
+
parser.add_argument(
|
545 |
+
"--no_half_vae",
|
546 |
+
action="store_true",
|
547 |
+
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
548 |
+
)
|
549 |
+
|
550 |
+
return parser
|
551 |
+
|
552 |
+
|
553 |
+
if __name__ == "__main__":
|
554 |
+
parser = setup_parser()
|
555 |
+
|
556 |
+
args = parser.parse_args()
|
557 |
+
train_util.verify_command_line_training_args(args)
|
558 |
+
args = train_util.read_config_from_file(args, parser)
|
559 |
+
|
560 |
+
train(args)
|
flags.png
ADDED
![]() |
flow.gif
ADDED
![]() |
Git LFS Details
|
flux_extract_lora.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# extract approximating LoRA by svd from two FLUX models
|
2 |
+
# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
3 |
+
# Thanks to cloneofsimo!
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
import time
|
9 |
+
import torch
|
10 |
+
from safetensors.torch import load_file, save_file
|
11 |
+
from safetensors import safe_open
|
12 |
+
from tqdm import tqdm
|
13 |
+
from .library import flux_utils, sai_model_spec
|
14 |
+
from .library.utils import MemoryEfficientSafeOpen
|
15 |
+
from .library.utils import setup_logging
|
16 |
+
from .networks import lora_flux
|
17 |
+
|
18 |
+
setup_logging()
|
19 |
+
import logging
|
20 |
+
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
from comfy.utils import ProgressBar
|
24 |
+
# CLAMP_QUANTILE = 0.99
|
25 |
+
# MIN_DIFF = 1e-1
|
26 |
+
|
27 |
+
|
28 |
+
def save_to_file(file_name, state_dict, metadata, dtype):
|
29 |
+
if dtype is not None:
|
30 |
+
for key in list(state_dict.keys()):
|
31 |
+
if type(state_dict[key]) == torch.Tensor:
|
32 |
+
state_dict[key] = state_dict[key].to(dtype)
|
33 |
+
|
34 |
+
save_file(state_dict, file_name, metadata=metadata)
|
35 |
+
|
36 |
+
|
37 |
+
def svd(
|
38 |
+
model_org=None,
|
39 |
+
model_tuned=None,
|
40 |
+
save_to=None,
|
41 |
+
dim=4,
|
42 |
+
device=None,
|
43 |
+
store_device='cpu',
|
44 |
+
save_precision=None,
|
45 |
+
clamp_quantile=0.99,
|
46 |
+
min_diff=0.01,
|
47 |
+
no_metadata=False,
|
48 |
+
mem_eff_safe_open=False,
|
49 |
+
):
|
50 |
+
def str_to_dtype(p):
|
51 |
+
if p == "float":
|
52 |
+
return torch.float
|
53 |
+
if p == "fp16":
|
54 |
+
return torch.float16
|
55 |
+
if p == "bf16":
|
56 |
+
return torch.bfloat16
|
57 |
+
return None
|
58 |
+
|
59 |
+
calc_dtype = torch.float
|
60 |
+
save_dtype = str_to_dtype(save_precision)
|
61 |
+
|
62 |
+
# open models
|
63 |
+
lora_weights = {}
|
64 |
+
if not mem_eff_safe_open:
|
65 |
+
# use original safetensors.safe_open
|
66 |
+
open_fn = lambda fn: safe_open(fn, framework="pt")
|
67 |
+
else:
|
68 |
+
logger.info("Using memory efficient safe_open")
|
69 |
+
open_fn = lambda fn: MemoryEfficientSafeOpen(fn)
|
70 |
+
|
71 |
+
with open_fn(model_org) as fo:
|
72 |
+
# filter keys
|
73 |
+
keys = []
|
74 |
+
for key in fo.keys():
|
75 |
+
if not ("single_block" in key or "double_block" in key):
|
76 |
+
continue
|
77 |
+
if ".bias" in key:
|
78 |
+
continue
|
79 |
+
if "norm" in key:
|
80 |
+
continue
|
81 |
+
keys.append(key)
|
82 |
+
comfy_pbar = ProgressBar(len(keys))
|
83 |
+
with open_fn(model_tuned) as ft:
|
84 |
+
for key in tqdm(keys):
|
85 |
+
# get tensors and calculate difference
|
86 |
+
value_o = fo.get_tensor(key)
|
87 |
+
value_t = ft.get_tensor(key)
|
88 |
+
mat = value_t.to(calc_dtype) - value_o.to(calc_dtype)
|
89 |
+
del value_o, value_t
|
90 |
+
|
91 |
+
# extract LoRA weights
|
92 |
+
if device:
|
93 |
+
mat = mat.to(device)
|
94 |
+
out_dim, in_dim = mat.size()[0:2]
|
95 |
+
rank = min(dim, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
96 |
+
|
97 |
+
mat = mat.squeeze()
|
98 |
+
|
99 |
+
U, S, Vh = torch.linalg.svd(mat)
|
100 |
+
|
101 |
+
U = U[:, :rank]
|
102 |
+
S = S[:rank]
|
103 |
+
U = U @ torch.diag(S)
|
104 |
+
|
105 |
+
Vh = Vh[:rank, :]
|
106 |
+
|
107 |
+
dist = torch.cat([U.flatten(), Vh.flatten()])
|
108 |
+
hi_val = torch.quantile(dist, clamp_quantile)
|
109 |
+
low_val = -hi_val
|
110 |
+
|
111 |
+
U = U.clamp(low_val, hi_val)
|
112 |
+
Vh = Vh.clamp(low_val, hi_val)
|
113 |
+
|
114 |
+
U = U.to(store_device, dtype=save_dtype).contiguous()
|
115 |
+
Vh = Vh.to(store_device, dtype=save_dtype).contiguous()
|
116 |
+
|
117 |
+
print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}")
|
118 |
+
comfy_pbar.update(1)
|
119 |
+
lora_weights[key] = (U, Vh)
|
120 |
+
del mat, U, S, Vh
|
121 |
+
|
122 |
+
# make state dict for LoRA
|
123 |
+
lora_sd = {}
|
124 |
+
for key, (up_weight, down_weight) in lora_weights.items():
|
125 |
+
lora_name = key.replace(".weight", "").replace(".", "_")
|
126 |
+
lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + lora_name
|
127 |
+
lora_sd[lora_name + ".lora_up.weight"] = up_weight
|
128 |
+
lora_sd[lora_name + ".lora_down.weight"] = down_weight
|
129 |
+
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0]) # same as rank
|
130 |
+
|
131 |
+
# minimum metadata
|
132 |
+
net_kwargs = {}
|
133 |
+
metadata = {
|
134 |
+
"ss_v2": str(False),
|
135 |
+
"ss_base_model_version": flux_utils.MODEL_VERSION_FLUX_V1,
|
136 |
+
"ss_network_module": "networks.lora_flux",
|
137 |
+
"ss_network_dim": str(dim),
|
138 |
+
"ss_network_alpha": str(float(dim)),
|
139 |
+
"ss_network_args": json.dumps(net_kwargs),
|
140 |
+
}
|
141 |
+
|
142 |
+
if not no_metadata:
|
143 |
+
title = os.path.splitext(os.path.basename(save_to))[0]
|
144 |
+
sai_metadata = sai_model_spec.build_metadata(lora_sd, False, False, False, True, False, time.time(), title, flux="dev")
|
145 |
+
metadata.update(sai_metadata)
|
146 |
+
|
147 |
+
save_to_file(save_to, lora_sd, metadata, save_dtype)
|
148 |
+
|
149 |
+
logger.info(f"LoRA weights saved to {save_to}")
|
150 |
+
return save_to
|
151 |
+
|
152 |
+
|
153 |
+
def setup_parser() -> argparse.ArgumentParser:
|
154 |
+
parser = argparse.ArgumentParser()
|
155 |
+
parser.add_argument(
|
156 |
+
"--save_precision",
|
157 |
+
type=str,
|
158 |
+
default=None,
|
159 |
+
choices=[None, "float", "fp16", "bf16"],
|
160 |
+
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat",
|
161 |
+
)
|
162 |
+
parser.add_argument(
|
163 |
+
"--model_org",
|
164 |
+
type=str,
|
165 |
+
default=None,
|
166 |
+
required=True,
|
167 |
+
help="Original model: safetensors file / 元モデル、safetensors",
|
168 |
+
)
|
169 |
+
parser.add_argument(
|
170 |
+
"--model_tuned",
|
171 |
+
type=str,
|
172 |
+
default=None,
|
173 |
+
required=True,
|
174 |
+
help="Tuned model, LoRA is difference of `original to tuned`: safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
|
175 |
+
)
|
176 |
+
parser.add_argument(
|
177 |
+
"--mem_eff_safe_open",
|
178 |
+
action="store_true",
|
179 |
+
help="use memory efficient safe_open. This is an experimental feature, use only when memory is not enough."
|
180 |
+
" / メモリ効率の良いsafe_openを使用する。実装は実験的なものなので、メモリが足りない場合のみ使用してください。",
|
181 |
+
)
|
182 |
+
parser.add_argument(
|
183 |
+
"--save_to",
|
184 |
+
type=str,
|
185 |
+
default=None,
|
186 |
+
required=True,
|
187 |
+
help="destination file name: safetensors file / 保存先のファイル名、safetensors",
|
188 |
+
)
|
189 |
+
parser.add_argument(
|
190 |
+
"--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)"
|
191 |
+
)
|
192 |
+
parser.add_argument(
|
193 |
+
"--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
|
194 |
+
)
|
195 |
+
parser.add_argument(
|
196 |
+
"--clamp_quantile",
|
197 |
+
type=float,
|
198 |
+
default=0.99,
|
199 |
+
help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
|
200 |
+
)
|
201 |
+
# parser.add_argument(
|
202 |
+
# "--min_diff",
|
203 |
+
# type=float,
|
204 |
+
# default=0.01,
|
205 |
+
# help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
|
206 |
+
# + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
|
207 |
+
# )
|
208 |
+
parser.add_argument(
|
209 |
+
"--no_metadata",
|
210 |
+
action="store_true",
|
211 |
+
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
|
212 |
+
+ "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
|
213 |
+
)
|
214 |
+
return parser
|
215 |
+
|
216 |
+
|
217 |
+
if __name__ == "__main__":
|
218 |
+
parser = setup_parser()
|
219 |
+
|
220 |
+
args = parser.parse_args()
|
221 |
+
svd(**vars(args))
|
flux_train_comfy.py
ADDED
@@ -0,0 +1,806 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# training with captions
|
2 |
+
|
3 |
+
# Swap blocks between CPU and GPU:
|
4 |
+
# This implementation is inspired by and based on the work of 2kpr.
|
5 |
+
# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading.
|
6 |
+
# The original idea has been adapted and extended to fit the current project's needs.
|
7 |
+
|
8 |
+
# Key features:
|
9 |
+
# - CPU offloading during forward and backward passes
|
10 |
+
# - Use of fused optimizer and grad_hook for efficient gradient processing
|
11 |
+
# - Per-block fused optimizer instances
|
12 |
+
|
13 |
+
import argparse
|
14 |
+
import copy
|
15 |
+
import math
|
16 |
+
import os
|
17 |
+
from multiprocessing import Value
|
18 |
+
import toml
|
19 |
+
|
20 |
+
from tqdm import tqdm
|
21 |
+
|
22 |
+
import torch
|
23 |
+
from .library.device_utils import init_ipex, clean_memory_on_device
|
24 |
+
|
25 |
+
init_ipex()
|
26 |
+
|
27 |
+
from accelerate.utils import set_seed
|
28 |
+
from .library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux
|
29 |
+
from .library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
30 |
+
|
31 |
+
from .library import train_util as train_util
|
32 |
+
|
33 |
+
from .library.utils import setup_logging, add_logging_arguments
|
34 |
+
|
35 |
+
setup_logging()
|
36 |
+
import logging
|
37 |
+
|
38 |
+
logger = logging.getLogger(__name__)
|
39 |
+
|
40 |
+
from .library import config_util as config_util
|
41 |
+
|
42 |
+
from .library.config_util import (
|
43 |
+
ConfigSanitizer,
|
44 |
+
BlueprintGenerator,
|
45 |
+
)
|
46 |
+
from .library.custom_train_functions import apply_masked_loss, add_custom_train_arguments
|
47 |
+
|
48 |
+
|
49 |
+
class FluxTrainer:
|
50 |
+
def __init__(self):
|
51 |
+
self.sample_prompts_te_outputs = None
|
52 |
+
|
53 |
+
def sample_images(self, epoch, global_step, validation_settings):
|
54 |
+
image_tensors = flux_train_utils.sample_images(
|
55 |
+
self.accelerator, self.args, epoch, global_step, self.unet, self.vae, self.text_encoder, self.sample_prompts_te_outputs, validation_settings)
|
56 |
+
return image_tensors
|
57 |
+
|
58 |
+
def init_train(self, args):
|
59 |
+
train_util.verify_training_args(args)
|
60 |
+
train_util.prepare_dataset_args(args, True)
|
61 |
+
# sdxl_train_util.verify_sdxl_training_args(args)
|
62 |
+
deepspeed_utils.prepare_deepspeed_args(args)
|
63 |
+
setup_logging(args, reset=True)
|
64 |
+
|
65 |
+
# temporary: backward compatibility for deprecated options. remove in the future
|
66 |
+
if not args.skip_cache_check:
|
67 |
+
args.skip_cache_check = args.skip_latents_validity_check
|
68 |
+
|
69 |
+
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
70 |
+
logger.warning(
|
71 |
+
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
|
72 |
+
)
|
73 |
+
args.cache_text_encoder_outputs = True
|
74 |
+
|
75 |
+
if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
|
76 |
+
logger.warning(
|
77 |
+
"cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります"
|
78 |
+
)
|
79 |
+
args.gradient_checkpointing = True
|
80 |
+
|
81 |
+
assert (
|
82 |
+
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
83 |
+
) or not args.cpu_offload_checkpointing, (
|
84 |
+
"blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
|
85 |
+
)
|
86 |
+
|
87 |
+
cache_latents = args.cache_latents
|
88 |
+
use_dreambooth_method = args.in_json is None
|
89 |
+
|
90 |
+
if args.seed is not None:
|
91 |
+
set_seed(args.seed)
|
92 |
+
|
93 |
+
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
94 |
+
if args.cache_latents:
|
95 |
+
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(
|
96 |
+
args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
|
97 |
+
)
|
98 |
+
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
99 |
+
|
100 |
+
# Prepare the dataset
|
101 |
+
if args.dataset_class is None:
|
102 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
|
103 |
+
if args.dataset_config is not None:
|
104 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
105 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
106 |
+
ignored = ["train_data_dir", "in_json"]
|
107 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
108 |
+
logger.warning(
|
109 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
110 |
+
", ".join(ignored)
|
111 |
+
)
|
112 |
+
)
|
113 |
+
else:
|
114 |
+
if use_dreambooth_method:
|
115 |
+
logger.info("Using DreamBooth method.")
|
116 |
+
user_config = {
|
117 |
+
"datasets": [
|
118 |
+
{
|
119 |
+
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
120 |
+
args.train_data_dir, args.reg_data_dir
|
121 |
+
)
|
122 |
+
}
|
123 |
+
]
|
124 |
+
}
|
125 |
+
else:
|
126 |
+
logger.info("Training with captions.")
|
127 |
+
user_config = {
|
128 |
+
"datasets": [
|
129 |
+
{
|
130 |
+
"subsets": [
|
131 |
+
{
|
132 |
+
"image_dir": args.train_data_dir,
|
133 |
+
"metadata_file": args.in_json,
|
134 |
+
}
|
135 |
+
]
|
136 |
+
}
|
137 |
+
]
|
138 |
+
}
|
139 |
+
|
140 |
+
blueprint = blueprint_generator.generate(user_config, args)
|
141 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
142 |
+
else:
|
143 |
+
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
144 |
+
|
145 |
+
current_epoch = Value("i", 0)
|
146 |
+
current_step = Value("i", 0)
|
147 |
+
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
148 |
+
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
149 |
+
|
150 |
+
train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認
|
151 |
+
|
152 |
+
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
153 |
+
if args.debug_dataset:
|
154 |
+
if args.cache_text_encoder_outputs:
|
155 |
+
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
|
156 |
+
strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
157 |
+
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
|
158 |
+
)
|
159 |
+
)
|
160 |
+
t5xxl_max_token_length = (
|
161 |
+
args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
|
162 |
+
)
|
163 |
+
strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))
|
164 |
+
|
165 |
+
train_dataset_group.set_current_strategies()
|
166 |
+
train_util.debug_dataset(train_dataset_group, True)
|
167 |
+
return
|
168 |
+
if len(train_dataset_group) == 0:
|
169 |
+
logger.error(
|
170 |
+
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
171 |
+
)
|
172 |
+
return
|
173 |
+
|
174 |
+
if cache_latents:
|
175 |
+
assert (
|
176 |
+
train_dataset_group.is_latent_cacheable()
|
177 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
178 |
+
|
179 |
+
if args.cache_text_encoder_outputs:
|
180 |
+
assert (
|
181 |
+
train_dataset_group.is_text_encoder_output_cacheable()
|
182 |
+
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
183 |
+
|
184 |
+
# acceleratorを準備する
|
185 |
+
logger.info("prepare accelerator")
|
186 |
+
accelerator = train_util.prepare_accelerator(args)
|
187 |
+
|
188 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
189 |
+
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
190 |
+
|
191 |
+
# load VAE for caching latents
|
192 |
+
ae = None
|
193 |
+
if cache_latents:
|
194 |
+
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
|
195 |
+
ae.to(accelerator.device, dtype=weight_dtype)
|
196 |
+
ae.requires_grad_(False)
|
197 |
+
ae.eval()
|
198 |
+
|
199 |
+
train_dataset_group.new_cache_latents(ae, accelerator)
|
200 |
+
|
201 |
+
ae.to("cpu") # if no sampling, vae can be deleted
|
202 |
+
clean_memory_on_device(accelerator.device)
|
203 |
+
|
204 |
+
accelerator.wait_for_everyone()
|
205 |
+
|
206 |
+
# prepare tokenize strategy
|
207 |
+
if args.t5xxl_max_token_length is None:
|
208 |
+
if is_schnell:
|
209 |
+
t5xxl_max_token_length = 256
|
210 |
+
else:
|
211 |
+
t5xxl_max_token_length = 512
|
212 |
+
else:
|
213 |
+
t5xxl_max_token_length = args.t5xxl_max_token_length
|
214 |
+
|
215 |
+
flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)
|
216 |
+
strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy)
|
217 |
+
|
218 |
+
# load clip_l, t5xxl for caching text encoder outputs
|
219 |
+
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
|
220 |
+
t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
|
221 |
+
clip_l.eval()
|
222 |
+
t5xxl.eval()
|
223 |
+
clip_l.requires_grad_(False)
|
224 |
+
t5xxl.requires_grad_(False)
|
225 |
+
|
226 |
+
text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask)
|
227 |
+
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
228 |
+
|
229 |
+
# cache text encoder outputs
|
230 |
+
sample_prompts_te_outputs = None
|
231 |
+
if args.cache_text_encoder_outputs:
|
232 |
+
# Text Encodes are eval and no grad here
|
233 |
+
clip_l.to(accelerator.device)
|
234 |
+
t5xxl.to(accelerator.device)
|
235 |
+
|
236 |
+
text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
237 |
+
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask
|
238 |
+
)
|
239 |
+
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
|
240 |
+
|
241 |
+
with accelerator.autocast():
|
242 |
+
train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator)
|
243 |
+
|
244 |
+
# cache sample prompt's embeddings to free text encoder's memory
|
245 |
+
if args.sample_prompts is not None:
|
246 |
+
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
|
247 |
+
|
248 |
+
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
249 |
+
|
250 |
+
prompts = []
|
251 |
+
for line in args.sample_prompts:
|
252 |
+
line = line.strip()
|
253 |
+
if len(line) > 0 and line[0] != "#":
|
254 |
+
prompts.append(line)
|
255 |
+
|
256 |
+
# preprocess prompts
|
257 |
+
for i in range(len(prompts)):
|
258 |
+
prompt_dict = prompts[i]
|
259 |
+
if isinstance(prompt_dict, str):
|
260 |
+
from .library.train_util import line_to_prompt_dict
|
261 |
+
|
262 |
+
prompt_dict = line_to_prompt_dict(prompt_dict)
|
263 |
+
prompts[i] = prompt_dict
|
264 |
+
assert isinstance(prompt_dict, dict)
|
265 |
+
|
266 |
+
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
|
267 |
+
prompt_dict["enum"] = i
|
268 |
+
prompt_dict.pop("subset", None)
|
269 |
+
|
270 |
+
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
271 |
+
with accelerator.autocast(), torch.no_grad():
|
272 |
+
for prompt_dict in prompts:
|
273 |
+
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
274 |
+
if p not in sample_prompts_te_outputs:
|
275 |
+
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
276 |
+
tokens_and_masks = flux_tokenize_strategy.tokenize(p)
|
277 |
+
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
278 |
+
flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
279 |
+
)
|
280 |
+
self.sample_prompts_te_outputs = sample_prompts_te_outputs
|
281 |
+
accelerator.wait_for_everyone()
|
282 |
+
|
283 |
+
# now we can delete Text Encoders to free memory
|
284 |
+
clip_l = None
|
285 |
+
t5xxl = None
|
286 |
+
clean_memory_on_device(accelerator.device)
|
287 |
+
|
288 |
+
# load FLUX
|
289 |
+
_, flux = flux_utils.load_flow_model(
|
290 |
+
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
|
291 |
+
)
|
292 |
+
|
293 |
+
if args.gradient_checkpointing:
|
294 |
+
flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
|
295 |
+
|
296 |
+
flux.requires_grad_(True)
|
297 |
+
|
298 |
+
# block swap
|
299 |
+
|
300 |
+
# backward compatibility
|
301 |
+
if args.blocks_to_swap is None:
|
302 |
+
blocks_to_swap = args.double_blocks_to_swap or 0
|
303 |
+
if args.single_blocks_to_swap is not None:
|
304 |
+
blocks_to_swap += args.single_blocks_to_swap // 2
|
305 |
+
if blocks_to_swap > 0:
|
306 |
+
logger.warning(
|
307 |
+
"double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead."
|
308 |
+
" / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。"
|
309 |
+
)
|
310 |
+
logger.info(
|
311 |
+
f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}."
|
312 |
+
)
|
313 |
+
args.blocks_to_swap = blocks_to_swap
|
314 |
+
del blocks_to_swap
|
315 |
+
|
316 |
+
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
317 |
+
if self.is_swapping_blocks:
|
318 |
+
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
|
319 |
+
# This idea is based on 2kpr's great work. Thank you!
|
320 |
+
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
321 |
+
flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
322 |
+
|
323 |
+
if not cache_latents:
|
324 |
+
# load VAE here if not cached
|
325 |
+
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu")
|
326 |
+
ae.requires_grad_(False)
|
327 |
+
ae.eval()
|
328 |
+
ae.to(accelerator.device, dtype=weight_dtype)
|
329 |
+
|
330 |
+
training_models = []
|
331 |
+
params_to_optimize = []
|
332 |
+
training_models.append(flux)
|
333 |
+
name_and_params = list(flux.named_parameters())
|
334 |
+
# single param group for now
|
335 |
+
params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate})
|
336 |
+
param_names = [[n for n, _ in name_and_params]]
|
337 |
+
|
338 |
+
# calculate number of trainable parameters
|
339 |
+
n_params = 0
|
340 |
+
for group in params_to_optimize:
|
341 |
+
for p in group["params"]:
|
342 |
+
n_params += p.numel()
|
343 |
+
|
344 |
+
accelerator.print(f"number of trainable parameters: {n_params}")
|
345 |
+
|
346 |
+
# 学習に必要なクラスを準備する
|
347 |
+
accelerator.print("prepare optimizer, data loader etc.")
|
348 |
+
|
349 |
+
if args.blockwise_fused_optimizers:
|
350 |
+
# fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
|
351 |
+
# Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
|
352 |
+
# This balances memory usage and management complexity.
|
353 |
+
|
354 |
+
# split params into groups. currently different learning rates are not supported
|
355 |
+
grouped_params = []
|
356 |
+
param_group = {}
|
357 |
+
for group in params_to_optimize:
|
358 |
+
named_parameters = list(flux.named_parameters())
|
359 |
+
assert len(named_parameters) == len(group["params"]), "number of parameters does not match"
|
360 |
+
for p, np in zip(group["params"], named_parameters):
|
361 |
+
# determine target layer and block index for each parameter
|
362 |
+
block_type = "other" # double, single or other
|
363 |
+
if np[0].startswith("double_blocks"):
|
364 |
+
block_index = int(np[0].split(".")[1])
|
365 |
+
block_type = "double"
|
366 |
+
elif np[0].startswith("single_blocks"):
|
367 |
+
block_index = int(np[0].split(".")[1])
|
368 |
+
block_type = "single"
|
369 |
+
else:
|
370 |
+
block_index = -1
|
371 |
+
|
372 |
+
param_group_key = (block_type, block_index)
|
373 |
+
if param_group_key not in param_group:
|
374 |
+
param_group[param_group_key] = []
|
375 |
+
param_group[param_group_key].append(p)
|
376 |
+
|
377 |
+
block_types_and_indices = []
|
378 |
+
for param_group_key, param_group in param_group.items():
|
379 |
+
block_types_and_indices.append(param_group_key)
|
380 |
+
grouped_params.append({"params": param_group, "lr": args.learning_rate})
|
381 |
+
|
382 |
+
num_params = 0
|
383 |
+
for p in param_group:
|
384 |
+
num_params += p.numel()
|
385 |
+
accelerator.print(f"block {param_group_key}: {num_params} parameters")
|
386 |
+
|
387 |
+
# prepare optimizers for each group
|
388 |
+
optimizers = []
|
389 |
+
for group in grouped_params:
|
390 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
|
391 |
+
optimizers.append(optimizer)
|
392 |
+
optimizer = optimizers[0] # avoid error in the following code
|
393 |
+
|
394 |
+
logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")
|
395 |
+
|
396 |
+
if train_util.is_schedulefree_optimizer(optimizers[0], args):
|
397 |
+
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
|
398 |
+
self.optimizer_train_fn = lambda: None # dummy function
|
399 |
+
self.optimizer_eval_fn = lambda: None # dummy function
|
400 |
+
else:
|
401 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
402 |
+
self.optimizer_train_fn, self.optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
403 |
+
|
404 |
+
# prepare dataloader
|
405 |
+
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
406 |
+
# some strategies can be None
|
407 |
+
train_dataset_group.set_current_strategies()
|
408 |
+
|
409 |
+
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
410 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
411 |
+
train_dataloader = torch.utils.data.DataLoader(
|
412 |
+
train_dataset_group,
|
413 |
+
batch_size=1,
|
414 |
+
shuffle=True,
|
415 |
+
collate_fn=collator,
|
416 |
+
num_workers=n_workers,
|
417 |
+
persistent_workers=args.persistent_data_loader_workers,
|
418 |
+
)
|
419 |
+
|
420 |
+
# 学習ステップ数を計算する
|
421 |
+
if args.max_train_epochs is not None:
|
422 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
423 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
424 |
+
)
|
425 |
+
accelerator.print(
|
426 |
+
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
427 |
+
)
|
428 |
+
|
429 |
+
# データセット側にも学習ステップを送信
|
430 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
431 |
+
|
432 |
+
# lr schedulerを用意する
|
433 |
+
if args.blockwise_fused_optimizers:
|
434 |
+
# prepare lr schedulers for each optimizer
|
435 |
+
lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
|
436 |
+
lr_scheduler = lr_schedulers[0] # avoid error in the following code
|
437 |
+
else:
|
438 |
+
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
439 |
+
|
440 |
+
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
441 |
+
if args.full_fp16:
|
442 |
+
assert (
|
443 |
+
args.mixed_precision == "fp16"
|
444 |
+
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
445 |
+
accelerator.print("enable full fp16 training.")
|
446 |
+
flux.to(weight_dtype)
|
447 |
+
if clip_l is not None:
|
448 |
+
clip_l.to(weight_dtype)
|
449 |
+
t5xxl.to(weight_dtype) # TODO check works with fp16 or not
|
450 |
+
elif args.full_bf16:
|
451 |
+
assert (
|
452 |
+
args.mixed_precision == "bf16"
|
453 |
+
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
454 |
+
accelerator.print("enable full bf16 training.")
|
455 |
+
flux.to(weight_dtype)
|
456 |
+
if clip_l is not None:
|
457 |
+
clip_l.to(weight_dtype)
|
458 |
+
t5xxl.to(weight_dtype)
|
459 |
+
|
460 |
+
# if we don't cache text encoder outputs, move them to device
|
461 |
+
if not args.cache_text_encoder_outputs:
|
462 |
+
clip_l.to(accelerator.device)
|
463 |
+
t5xxl.to(accelerator.device)
|
464 |
+
|
465 |
+
clean_memory_on_device(accelerator.device)
|
466 |
+
|
467 |
+
if args.deepspeed:
|
468 |
+
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=flux)
|
469 |
+
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
|
470 |
+
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
471 |
+
ds_model, optimizer, train_dataloader, lr_scheduler
|
472 |
+
)
|
473 |
+
training_models = [ds_model]
|
474 |
+
|
475 |
+
else:
|
476 |
+
# accelerator does some magic
|
477 |
+
# if we doesn't swap blocks, we can move the model to device
|
478 |
+
flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
|
479 |
+
if self.is_swapping_blocks:
|
480 |
+
accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
|
481 |
+
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
482 |
+
|
483 |
+
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
484 |
+
if args.full_fp16:
|
485 |
+
# During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do.
|
486 |
+
# -> But we think it's ok to patch accelerator even if deepspeed is enabled.
|
487 |
+
train_util.patch_accelerator_for_fp16_training(accelerator)
|
488 |
+
|
489 |
+
# resumeする
|
490 |
+
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
491 |
+
|
492 |
+
if args.fused_backward_pass:
|
493 |
+
# use fused optimizer for backward pass: other optimizers will be supported in the future
|
494 |
+
from .library import adafactor_fused
|
495 |
+
|
496 |
+
adafactor_fused.patch_adafactor_fused(optimizer)
|
497 |
+
|
498 |
+
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
|
499 |
+
for parameter, param_name in zip(param_group["params"], param_name_group):
|
500 |
+
if parameter.requires_grad:
|
501 |
+
|
502 |
+
def create_grad_hook(p_name, p_group):
|
503 |
+
def grad_hook(tensor: torch.Tensor):
|
504 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
505 |
+
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
506 |
+
optimizer.step_param(tensor, p_group)
|
507 |
+
tensor.grad = None
|
508 |
+
|
509 |
+
return grad_hook
|
510 |
+
|
511 |
+
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
|
512 |
+
|
513 |
+
elif args.blockwise_fused_optimizers:
|
514 |
+
# prepare for additional optimizers and lr schedulers
|
515 |
+
for i in range(1, len(optimizers)):
|
516 |
+
optimizers[i] = accelerator.prepare(optimizers[i])
|
517 |
+
lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
|
518 |
+
|
519 |
+
# counters are used to determine when to step the optimizer
|
520 |
+
global optimizer_hooked_count
|
521 |
+
global num_parameters_per_group
|
522 |
+
global parameter_optimizer_map
|
523 |
+
|
524 |
+
optimizer_hooked_count = {}
|
525 |
+
num_parameters_per_group = [0] * len(optimizers)
|
526 |
+
parameter_optimizer_map = {}
|
527 |
+
|
528 |
+
for opt_idx, optimizer in enumerate(optimizers):
|
529 |
+
for param_group in optimizer.param_groups:
|
530 |
+
for parameter in param_group["params"]:
|
531 |
+
if parameter.requires_grad:
|
532 |
+
|
533 |
+
def grad_hook(parameter: torch.Tensor):
|
534 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
535 |
+
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
|
536 |
+
|
537 |
+
i = parameter_optimizer_map[parameter]
|
538 |
+
optimizer_hooked_count[i] += 1
|
539 |
+
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
|
540 |
+
optimizers[i].step()
|
541 |
+
optimizers[i].zero_grad(set_to_none=True)
|
542 |
+
|
543 |
+
parameter.register_post_accumulate_grad_hook(grad_hook)
|
544 |
+
parameter_optimizer_map[parameter] = opt_idx
|
545 |
+
num_parameters_per_group[opt_idx] += 1
|
546 |
+
|
547 |
+
# epoch数を計算する
|
548 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
549 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
550 |
+
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
551 |
+
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
552 |
+
|
553 |
+
# 学習する
|
554 |
+
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
555 |
+
accelerator.print("running training")
|
556 |
+
accelerator.print(f" num examples: {train_dataset_group.num_train_images}")
|
557 |
+
accelerator.print(f" num batches per epoch: {len(train_dataloader)}")
|
558 |
+
accelerator.print(f" num epochs: {num_train_epochs}")
|
559 |
+
accelerator.print(
|
560 |
+
f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
561 |
+
)
|
562 |
+
# accelerator.print(
|
563 |
+
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
564 |
+
# )
|
565 |
+
accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}")
|
566 |
+
accelerator.print(f" total optimization steps: {args.max_train_steps}")
|
567 |
+
|
568 |
+
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
569 |
+
self.global_step = 0
|
570 |
+
|
571 |
+
noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
572 |
+
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
573 |
+
|
574 |
+
if accelerator.is_main_process:
|
575 |
+
init_kwargs = {}
|
576 |
+
if args.wandb_run_name:
|
577 |
+
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
578 |
+
if args.log_tracker_config is not None:
|
579 |
+
init_kwargs = toml.load(args.log_tracker_config)
|
580 |
+
accelerator.init_trackers(
|
581 |
+
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
|
582 |
+
config=train_util.get_sanitized_config_or_none(args),
|
583 |
+
init_kwargs=init_kwargs,
|
584 |
+
)
|
585 |
+
|
586 |
+
if self.is_swapping_blocks:
|
587 |
+
accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
|
588 |
+
|
589 |
+
# For --sample_at_first
|
590 |
+
#flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
|
591 |
+
|
592 |
+
self.loss_recorder = train_util.LossRecorder()
|
593 |
+
epoch = 0 # avoid error when max_train_steps is 0
|
594 |
+
|
595 |
+
self.tokens_and_masks = tokens_and_masks
|
596 |
+
self.num_train_epochs = num_train_epochs
|
597 |
+
self.current_epoch = current_epoch
|
598 |
+
self.args = args
|
599 |
+
self.accelerator = accelerator
|
600 |
+
self.unet = flux
|
601 |
+
self.vae = ae
|
602 |
+
self.text_encoder = [clip_l, t5xxl]
|
603 |
+
self.save_dtype = save_dtype
|
604 |
+
|
605 |
+
def training_loop(break_at_steps, epoch):
|
606 |
+
global optimizer_hooked_count
|
607 |
+
steps_done = 0
|
608 |
+
#accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
609 |
+
progress_bar.set_description(f"Epoch {epoch + 1}/{num_train_epochs} - steps")
|
610 |
+
current_epoch.value = epoch + 1
|
611 |
+
|
612 |
+
for m in training_models:
|
613 |
+
m.train()
|
614 |
+
|
615 |
+
for step, batch in enumerate(train_dataloader):
|
616 |
+
current_step.value = self.global_step
|
617 |
+
|
618 |
+
if args.blockwise_fused_optimizers:
|
619 |
+
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
|
620 |
+
|
621 |
+
with accelerator.accumulate(*training_models):
|
622 |
+
if "latents" in batch and batch["latents"] is not None:
|
623 |
+
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
|
624 |
+
else:
|
625 |
+
with torch.no_grad():
|
626 |
+
# encode images to latents. images are [-1, 1]
|
627 |
+
latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype)
|
628 |
+
|
629 |
+
# NaNが含まれていれば警告を表示し0に置き換える
|
630 |
+
if torch.any(torch.isnan(latents)):
|
631 |
+
accelerator.print("NaN found in latents, replacing with zeros")
|
632 |
+
latents = torch.nan_to_num(latents, 0, out=latents)
|
633 |
+
|
634 |
+
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
635 |
+
if text_encoder_outputs_list is not None:
|
636 |
+
text_encoder_conds = text_encoder_outputs_list
|
637 |
+
else:
|
638 |
+
# not cached or training, so get from text encoders
|
639 |
+
tokens_and_masks = batch["input_ids_list"]
|
640 |
+
with torch.no_grad():
|
641 |
+
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
642 |
+
text_encoder_conds = text_encoding_strategy.encode_tokens(
|
643 |
+
flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
|
644 |
+
)
|
645 |
+
if args.full_fp16:
|
646 |
+
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
|
647 |
+
|
648 |
+
# TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
|
649 |
+
|
650 |
+
# Sample noise that we'll add to the latents
|
651 |
+
noise = torch.randn_like(latents)
|
652 |
+
bsz = latents.shape[0]
|
653 |
+
|
654 |
+
# get noisy model input and timesteps
|
655 |
+
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
656 |
+
args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
|
657 |
+
)
|
658 |
+
|
659 |
+
# pack latents and get img_ids
|
660 |
+
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
661 |
+
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
662 |
+
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
|
663 |
+
|
664 |
+
# get guidance: ensure args.guidance_scale is float
|
665 |
+
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
|
666 |
+
|
667 |
+
# call model
|
668 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
669 |
+
if not args.apply_t5_attn_mask:
|
670 |
+
t5_attn_mask = None
|
671 |
+
|
672 |
+
if args.bypass_flux_guidance:
|
673 |
+
flux_utils.bypass_flux_guidance(flux)
|
674 |
+
|
675 |
+
with accelerator.autocast():
|
676 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
677 |
+
model_pred = flux(
|
678 |
+
img=packed_noisy_model_input,
|
679 |
+
img_ids=img_ids,
|
680 |
+
txt=t5_out,
|
681 |
+
txt_ids=txt_ids,
|
682 |
+
y=l_pooled,
|
683 |
+
timesteps=timesteps / 1000,
|
684 |
+
guidance=guidance_vec,
|
685 |
+
txt_attention_mask=t5_attn_mask,
|
686 |
+
)
|
687 |
+
|
688 |
+
# unpack latents
|
689 |
+
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
690 |
+
|
691 |
+
if args.bypass_flux_guidance:
|
692 |
+
flux_utils.restore_flux_guidance(flux)
|
693 |
+
|
694 |
+
# apply model prediction type
|
695 |
+
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
696 |
+
|
697 |
+
# flow matching loss: this is different from SD3
|
698 |
+
target = noise - latents
|
699 |
+
|
700 |
+
# calculate loss
|
701 |
+
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
702 |
+
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
703 |
+
if weighting is not None:
|
704 |
+
loss = loss * weighting
|
705 |
+
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
706 |
+
loss = apply_masked_loss(loss, batch)
|
707 |
+
loss = loss.mean([1, 2, 3])
|
708 |
+
|
709 |
+
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
710 |
+
loss = loss * loss_weights
|
711 |
+
loss = loss.mean()
|
712 |
+
|
713 |
+
# backward
|
714 |
+
accelerator.backward(loss)
|
715 |
+
|
716 |
+
if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
|
717 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
718 |
+
params_to_clip = []
|
719 |
+
for m in training_models:
|
720 |
+
params_to_clip.extend(m.parameters())
|
721 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
722 |
+
|
723 |
+
optimizer.step()
|
724 |
+
lr_scheduler.step()
|
725 |
+
optimizer.zero_grad(set_to_none=True)
|
726 |
+
else:
|
727 |
+
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
|
728 |
+
lr_scheduler.step()
|
729 |
+
if args.blockwise_fused_optimizers:
|
730 |
+
for i in range(1, len(optimizers)):
|
731 |
+
lr_schedulers[i].step()
|
732 |
+
|
733 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
734 |
+
if accelerator.sync_gradients:
|
735 |
+
progress_bar.update(1)
|
736 |
+
self.global_step += 1
|
737 |
+
|
738 |
+
|
739 |
+
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
740 |
+
if len(accelerator.trackers) > 0:
|
741 |
+
logs = {"loss": current_loss}
|
742 |
+
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
|
743 |
+
|
744 |
+
accelerator.log(logs, step=self.global_step)
|
745 |
+
|
746 |
+
self.loss_recorder.add(epoch=epoch, step=step, loss=current_loss, global_step=self.global_step)
|
747 |
+
avr_loss: float = self.loss_recorder.moving_average
|
748 |
+
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
749 |
+
progress_bar.set_postfix(**logs)
|
750 |
+
|
751 |
+
if self.global_step >= break_at_steps:
|
752 |
+
break
|
753 |
+
steps_done += 1
|
754 |
+
|
755 |
+
if len(accelerator.trackers) > 0:
|
756 |
+
logs = {"loss/epoch": self.loss_recorder.moving_average}
|
757 |
+
accelerator.log(logs, step=epoch + 1)
|
758 |
+
return steps_done
|
759 |
+
|
760 |
+
return training_loop
|
761 |
+
|
762 |
+
def setup_parser() -> argparse.ArgumentParser:
|
763 |
+
parser = argparse.ArgumentParser()
|
764 |
+
|
765 |
+
add_logging_arguments(parser)
|
766 |
+
train_util.add_sd_models_arguments(parser) # TODO split this
|
767 |
+
train_util.add_dataset_arguments(parser, True, True, True)
|
768 |
+
train_util.add_training_arguments(parser, False)
|
769 |
+
train_util.add_masked_loss_arguments(parser)
|
770 |
+
deepspeed_utils.add_deepspeed_arguments(parser)
|
771 |
+
train_util.add_sd_saving_arguments(parser)
|
772 |
+
train_util.add_optimizer_arguments(parser)
|
773 |
+
config_util.add_config_arguments(parser)
|
774 |
+
add_custom_train_arguments(parser) # TODO remove this from here
|
775 |
+
train_util.add_dit_training_arguments(parser)
|
776 |
+
flux_train_utils.add_flux_train_arguments(parser)
|
777 |
+
|
778 |
+
parser.add_argument(
|
779 |
+
"--mem_eff_save",
|
780 |
+
action="store_true",
|
781 |
+
help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う",
|
782 |
+
)
|
783 |
+
|
784 |
+
parser.add_argument(
|
785 |
+
"--fused_optimizer_groups",
|
786 |
+
type=int,
|
787 |
+
default=None,
|
788 |
+
help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます",
|
789 |
+
)
|
790 |
+
parser.add_argument(
|
791 |
+
"--blockwise_fused_optimizers",
|
792 |
+
action="store_true",
|
793 |
+
help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
|
794 |
+
)
|
795 |
+
parser.add_argument(
|
796 |
+
"--skip_latents_validity_check",
|
797 |
+
action="store_true",
|
798 |
+
help="skip latents validity check / latentsの正当性チェックをスキップする",
|
799 |
+
)
|
800 |
+
|
801 |
+
parser.add_argument(
|
802 |
+
"--cpu_offload_checkpointing",
|
803 |
+
action="store_true",
|
804 |
+
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
|
805 |
+
)
|
806 |
+
return parser
|
flux_train_network_comfy.py
ADDED
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import copy
|
3 |
+
import math
|
4 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
5 |
+
import argparse
|
6 |
+
from .library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
|
7 |
+
from .train_network import NetworkTrainer, clean_memory_on_device, setup_parser
|
8 |
+
|
9 |
+
from accelerate import Accelerator
|
10 |
+
|
11 |
+
|
12 |
+
import logging
|
13 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
class FluxNetworkTrainer(NetworkTrainer):
|
17 |
+
def __init__(self):
|
18 |
+
super().__init__()
|
19 |
+
self.sample_prompts_te_outputs = None
|
20 |
+
self.is_schnell: Optional[bool] = None
|
21 |
+
self.is_swapping_blocks: bool = False
|
22 |
+
|
23 |
+
def assert_extra_args(self, args, train_dataset_group):
|
24 |
+
super().assert_extra_args(args, train_dataset_group)
|
25 |
+
# sdxl_train_util.verify_sdxl_training_args(args)
|
26 |
+
|
27 |
+
if args.fp8_base_unet:
|
28 |
+
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
|
29 |
+
|
30 |
+
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
31 |
+
logger.warning(
|
32 |
+
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
|
33 |
+
)
|
34 |
+
args.cache_text_encoder_outputs = True
|
35 |
+
|
36 |
+
if args.cache_text_encoder_outputs:
|
37 |
+
assert (
|
38 |
+
train_dataset_group.is_text_encoder_output_cacheable()
|
39 |
+
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
40 |
+
|
41 |
+
# prepare CLIP-L/T5XXL training flags
|
42 |
+
self.train_clip_l = not args.network_train_unet_only
|
43 |
+
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
|
44 |
+
|
45 |
+
if args.max_token_length is not None:
|
46 |
+
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
|
47 |
+
|
48 |
+
assert (
|
49 |
+
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
50 |
+
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
|
51 |
+
|
52 |
+
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
53 |
+
|
54 |
+
def load_target_model(self, args, weight_dtype, accelerator):
|
55 |
+
# currently offload to cpu for some models
|
56 |
+
|
57 |
+
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
|
58 |
+
loading_dtype = None if args.fp8_base else weight_dtype
|
59 |
+
|
60 |
+
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
|
61 |
+
self.is_schnell, model = flux_utils.load_flow_model(
|
62 |
+
args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
|
63 |
+
)
|
64 |
+
if args.fp8_base:
|
65 |
+
# check dtype of model
|
66 |
+
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2fnuz:
|
67 |
+
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
|
68 |
+
elif model.dtype == torch.float8_e4m3fn or model.dtype == torch.float8_e5m2:
|
69 |
+
logger.info(f"Loaded {model.dtype} FLUX model")
|
70 |
+
|
71 |
+
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
72 |
+
if self.is_swapping_blocks:
|
73 |
+
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
|
74 |
+
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
75 |
+
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
76 |
+
|
77 |
+
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
78 |
+
clip_l.eval()
|
79 |
+
|
80 |
+
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
|
81 |
+
if args.fp8_base and not args.fp8_base_unet:
|
82 |
+
loading_dtype = None # as is
|
83 |
+
else:
|
84 |
+
loading_dtype = weight_dtype
|
85 |
+
|
86 |
+
# loading t5xxl to cpu takes a long time, so we should load to gpu in future
|
87 |
+
t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
88 |
+
t5xxl.eval()
|
89 |
+
if args.fp8_base and not args.fp8_base_unet:
|
90 |
+
# check dtype of model
|
91 |
+
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
|
92 |
+
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
|
93 |
+
elif t5xxl.dtype == torch.float8_e4m3fn:
|
94 |
+
logger.info("Loaded fp8 T5XXL model")
|
95 |
+
|
96 |
+
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
97 |
+
|
98 |
+
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
|
99 |
+
|
100 |
+
def get_tokenize_strategy(self, args):
|
101 |
+
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
102 |
+
|
103 |
+
if args.t5xxl_max_token_length is None:
|
104 |
+
if is_schnell:
|
105 |
+
t5xxl_max_token_length = 256
|
106 |
+
else:
|
107 |
+
t5xxl_max_token_length = 512
|
108 |
+
else:
|
109 |
+
t5xxl_max_token_length = args.t5xxl_max_token_length
|
110 |
+
|
111 |
+
logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
|
112 |
+
return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
|
113 |
+
|
114 |
+
def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
|
115 |
+
return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
|
116 |
+
|
117 |
+
def get_latents_caching_strategy(self, args):
|
118 |
+
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
|
119 |
+
return latents_caching_strategy
|
120 |
+
|
121 |
+
def get_text_encoding_strategy(self, args):
|
122 |
+
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
|
123 |
+
|
124 |
+
def post_process_network(self, args, accelerator, network, text_encoders, unet):
|
125 |
+
# check t5xxl is trained or not
|
126 |
+
self.train_t5xxl = network.train_t5xxl
|
127 |
+
|
128 |
+
if self.train_t5xxl and args.cache_text_encoder_outputs:
|
129 |
+
raise ValueError(
|
130 |
+
"T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
|
131 |
+
)
|
132 |
+
|
133 |
+
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
134 |
+
if args.cache_text_encoder_outputs:
|
135 |
+
if self.train_clip_l and not self.train_t5xxl:
|
136 |
+
return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
|
137 |
+
else:
|
138 |
+
return None # no text encoders are needed for encoding because both are cached
|
139 |
+
else:
|
140 |
+
return text_encoders # both CLIP-L and T5XXL are needed for encoding
|
141 |
+
|
142 |
+
def get_text_encoders_train_flags(self, args, text_encoders):
|
143 |
+
return [self.train_clip_l, self.train_t5xxl]
|
144 |
+
|
145 |
+
def get_text_encoder_outputs_caching_strategy(self, args):
|
146 |
+
if args.cache_text_encoder_outputs:
|
147 |
+
# if the text encoders is trained, we need tokenization, so is_partial is True
|
148 |
+
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
149 |
+
args.cache_text_encoder_outputs_to_disk,
|
150 |
+
args.text_encoder_batch_size,
|
151 |
+
args.skip_cache_check,
|
152 |
+
is_partial=self.train_clip_l or self.train_t5xxl,
|
153 |
+
apply_t5_attn_mask=args.apply_t5_attn_mask,
|
154 |
+
)
|
155 |
+
else:
|
156 |
+
return None
|
157 |
+
|
158 |
+
def cache_text_encoder_outputs_if_needed(
|
159 |
+
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
160 |
+
):
|
161 |
+
if args.cache_text_encoder_outputs:
|
162 |
+
if not args.lowram:
|
163 |
+
# reduce memory consumption
|
164 |
+
logger.info("move vae and unet to cpu to save memory")
|
165 |
+
org_vae_device = vae.device
|
166 |
+
org_unet_device = unet.device
|
167 |
+
vae.to("cpu")
|
168 |
+
unet.to("cpu")
|
169 |
+
clean_memory_on_device(accelerator.device)
|
170 |
+
|
171 |
+
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
|
172 |
+
logger.info("move text encoders to gpu")
|
173 |
+
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
|
174 |
+
text_encoders[1].to(accelerator.device)
|
175 |
+
|
176 |
+
if text_encoders[1].dtype == torch.float8_e4m3fn:
|
177 |
+
# if we load fp8 weights, the model is already fp8, so we use it as is
|
178 |
+
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
|
179 |
+
else:
|
180 |
+
# otherwise, we need to convert it to target dtype
|
181 |
+
text_encoders[1].to(weight_dtype)
|
182 |
+
|
183 |
+
with accelerator.autocast():
|
184 |
+
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
|
185 |
+
|
186 |
+
# cache sample prompts
|
187 |
+
if args.sample_prompts is not None:
|
188 |
+
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
|
189 |
+
|
190 |
+
tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
|
191 |
+
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
192 |
+
|
193 |
+
prompts = []
|
194 |
+
for line in args.sample_prompts:
|
195 |
+
line = line.strip()
|
196 |
+
if len(line) > 0 and line[0] != "#":
|
197 |
+
prompts.append(line)
|
198 |
+
|
199 |
+
# preprocess prompts
|
200 |
+
for i in range(len(prompts)):
|
201 |
+
prompt_dict = prompts[i]
|
202 |
+
if isinstance(prompt_dict, str):
|
203 |
+
from .library.train_util import line_to_prompt_dict
|
204 |
+
|
205 |
+
prompt_dict = line_to_prompt_dict(prompt_dict)
|
206 |
+
prompts[i] = prompt_dict
|
207 |
+
assert isinstance(prompt_dict, dict)
|
208 |
+
|
209 |
+
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
|
210 |
+
prompt_dict["enum"] = i
|
211 |
+
prompt_dict.pop("subset", None)
|
212 |
+
|
213 |
+
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
214 |
+
with accelerator.autocast(), torch.no_grad():
|
215 |
+
for prompt_dict in prompts:
|
216 |
+
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
217 |
+
if p not in sample_prompts_te_outputs:
|
218 |
+
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
219 |
+
tokens_and_masks = tokenize_strategy.tokenize(p)
|
220 |
+
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
221 |
+
tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
|
222 |
+
)
|
223 |
+
self.sample_prompts_te_outputs = sample_prompts_te_outputs
|
224 |
+
|
225 |
+
accelerator.wait_for_everyone()
|
226 |
+
|
227 |
+
# move back to cpu
|
228 |
+
if not self.is_train_text_encoder(args):
|
229 |
+
logger.info("move CLIP-L back to cpu")
|
230 |
+
text_encoders[0].to("cpu")
|
231 |
+
logger.info("move t5XXL back to cpu")
|
232 |
+
text_encoders[1].to("cpu")
|
233 |
+
clean_memory_on_device(accelerator.device)
|
234 |
+
|
235 |
+
if not args.lowram:
|
236 |
+
logger.info("move vae and unet back to original device")
|
237 |
+
vae.to(org_vae_device)
|
238 |
+
unet.to(org_unet_device)
|
239 |
+
else:
|
240 |
+
# Text Encoder
|
241 |
+
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
242 |
+
text_encoders[1].to(accelerator.device)
|
243 |
+
|
244 |
+
def sample_images(self, epoch, global_step, validation_settings):
|
245 |
+
text_encoders = self.get_models_for_text_encoding(self.args, self.accelerator, self.text_encoder)
|
246 |
+
|
247 |
+
image_tensors = flux_train_utils.sample_images(
|
248 |
+
self.accelerator, self.args, epoch, global_step, self.unet, self.vae, text_encoders, self.sample_prompts_te_outputs, validation_settings)
|
249 |
+
clean_memory_on_device(self.accelerator.device)
|
250 |
+
return image_tensors
|
251 |
+
|
252 |
+
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
253 |
+
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
254 |
+
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
255 |
+
return noise_scheduler
|
256 |
+
|
257 |
+
def encode_images_to_latents(self, args, accelerator, vae, images):
|
258 |
+
return vae.encode(images)
|
259 |
+
|
260 |
+
def shift_scale_latents(self, args, latents):
|
261 |
+
return latents
|
262 |
+
|
263 |
+
def get_noise_pred_and_target(
|
264 |
+
self,
|
265 |
+
args,
|
266 |
+
accelerator,
|
267 |
+
noise_scheduler,
|
268 |
+
latents,
|
269 |
+
batch,
|
270 |
+
text_encoder_conds,
|
271 |
+
unet: flux_models.Flux,
|
272 |
+
network,
|
273 |
+
weight_dtype,
|
274 |
+
train_unet,
|
275 |
+
):
|
276 |
+
# Sample noise that we'll add to the latents
|
277 |
+
noise = torch.randn_like(latents)
|
278 |
+
bsz = latents.shape[0]
|
279 |
+
|
280 |
+
# get noisy model input and timesteps
|
281 |
+
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
282 |
+
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
283 |
+
)
|
284 |
+
|
285 |
+
# pack latents and get img_ids
|
286 |
+
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
287 |
+
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
288 |
+
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
|
289 |
+
|
290 |
+
# get guidance
|
291 |
+
# ensure guidance_scale in args is float
|
292 |
+
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
|
293 |
+
|
294 |
+
# ensure the hidden state will require grad
|
295 |
+
if args.gradient_checkpointing:
|
296 |
+
noisy_model_input.requires_grad_(True)
|
297 |
+
for t in text_encoder_conds:
|
298 |
+
if t is not None and t.dtype.is_floating_point:
|
299 |
+
t.requires_grad_(True)
|
300 |
+
img_ids.requires_grad_(True)
|
301 |
+
guidance_vec.requires_grad_(True)
|
302 |
+
|
303 |
+
# Predict the noise residual
|
304 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
305 |
+
if not args.apply_t5_attn_mask:
|
306 |
+
t5_attn_mask = None
|
307 |
+
|
308 |
+
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
309 |
+
# normal forward
|
310 |
+
with accelerator.autocast():
|
311 |
+
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
312 |
+
model_pred = unet(
|
313 |
+
img=img,
|
314 |
+
img_ids=img_ids,
|
315 |
+
txt=t5_out,
|
316 |
+
txt_ids=txt_ids,
|
317 |
+
y=l_pooled,
|
318 |
+
timesteps=timesteps / 1000,
|
319 |
+
guidance=guidance_vec,
|
320 |
+
txt_attention_mask=t5_attn_mask,
|
321 |
+
)
|
322 |
+
"""
|
323 |
+
else:
|
324 |
+
# split forward to reduce memory usage
|
325 |
+
assert network.train_blocks == "single", "train_blocks must be single for split mode"
|
326 |
+
with accelerator.autocast():
|
327 |
+
# move flux lower to cpu, and then move flux upper to gpu
|
328 |
+
unet.to("cpu")
|
329 |
+
clean_memory_on_device(accelerator.device)
|
330 |
+
self.flux_upper.to(accelerator.device)
|
331 |
+
|
332 |
+
# upper model does not require grad
|
333 |
+
with torch.no_grad():
|
334 |
+
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
|
335 |
+
img=packed_noisy_model_input,
|
336 |
+
img_ids=img_ids,
|
337 |
+
txt=t5_out,
|
338 |
+
txt_ids=txt_ids,
|
339 |
+
y=l_pooled,
|
340 |
+
timesteps=timesteps / 1000,
|
341 |
+
guidance=guidance_vec,
|
342 |
+
txt_attention_mask=t5_attn_mask,
|
343 |
+
)
|
344 |
+
|
345 |
+
# move flux upper back to cpu, and then move flux lower to gpu
|
346 |
+
self.flux_upper.to("cpu")
|
347 |
+
clean_memory_on_device(accelerator.device)
|
348 |
+
unet.to(accelerator.device)
|
349 |
+
|
350 |
+
# lower model requires grad
|
351 |
+
intermediate_img.requires_grad_(True)
|
352 |
+
intermediate_txt.requires_grad_(True)
|
353 |
+
vec.requires_grad_(True)
|
354 |
+
pe.requires_grad_(True)
|
355 |
+
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
356 |
+
"""
|
357 |
+
|
358 |
+
return model_pred
|
359 |
+
|
360 |
+
if args.bypass_flux_guidance:
|
361 |
+
flux_utils.bypass_flux_guidance(unet)
|
362 |
+
|
363 |
+
model_pred = call_dit(
|
364 |
+
img=packed_noisy_model_input,
|
365 |
+
img_ids=img_ids,
|
366 |
+
t5_out=t5_out,
|
367 |
+
txt_ids=txt_ids,
|
368 |
+
l_pooled=l_pooled,
|
369 |
+
timesteps=timesteps,
|
370 |
+
guidance_vec=guidance_vec,
|
371 |
+
t5_attn_mask=t5_attn_mask,
|
372 |
+
)
|
373 |
+
|
374 |
+
# unpack latents
|
375 |
+
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
376 |
+
|
377 |
+
if args.bypass_flux_guidance: #for flex
|
378 |
+
flux_utils.restore_flux_guidance(unet)
|
379 |
+
|
380 |
+
# apply model prediction type
|
381 |
+
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
382 |
+
|
383 |
+
# flow matching loss: this is different from SD3
|
384 |
+
target = noise - latents
|
385 |
+
|
386 |
+
# differential output preservation
|
387 |
+
if "custom_attributes" in batch:
|
388 |
+
diff_output_pr_indices = []
|
389 |
+
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
390 |
+
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
|
391 |
+
diff_output_pr_indices.append(i)
|
392 |
+
|
393 |
+
if len(diff_output_pr_indices) > 0:
|
394 |
+
network.set_multiplier(0.0)
|
395 |
+
unet.prepare_block_swap_before_forward()
|
396 |
+
with torch.no_grad():
|
397 |
+
model_pred_prior = call_dit(
|
398 |
+
img=packed_noisy_model_input[diff_output_pr_indices],
|
399 |
+
img_ids=img_ids[diff_output_pr_indices],
|
400 |
+
t5_out=t5_out[diff_output_pr_indices],
|
401 |
+
txt_ids=txt_ids[diff_output_pr_indices],
|
402 |
+
l_pooled=l_pooled[diff_output_pr_indices],
|
403 |
+
timesteps=timesteps[diff_output_pr_indices],
|
404 |
+
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
|
405 |
+
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
|
406 |
+
)
|
407 |
+
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
408 |
+
|
409 |
+
model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
|
410 |
+
model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
|
411 |
+
args,
|
412 |
+
model_pred_prior,
|
413 |
+
noisy_model_input[diff_output_pr_indices],
|
414 |
+
sigmas[diff_output_pr_indices] if sigmas is not None else None,
|
415 |
+
)
|
416 |
+
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
417 |
+
|
418 |
+
return model_pred, target, timesteps, weighting
|
419 |
+
|
420 |
+
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
421 |
+
return loss
|
422 |
+
|
423 |
+
def get_sai_model_spec(self, args):
|
424 |
+
return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
|
425 |
+
|
426 |
+
def update_metadata(self, metadata, args):
|
427 |
+
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
|
428 |
+
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
429 |
+
metadata["ss_logit_mean"] = args.logit_mean
|
430 |
+
metadata["ss_logit_std"] = args.logit_std
|
431 |
+
metadata["ss_mode_scale"] = args.mode_scale
|
432 |
+
metadata["ss_guidance_scale"] = args.guidance_scale
|
433 |
+
metadata["ss_timestep_sampling"] = args.timestep_sampling
|
434 |
+
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
|
435 |
+
metadata["ss_model_prediction_type"] = args.model_prediction_type
|
436 |
+
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
437 |
+
|
438 |
+
def is_text_encoder_not_needed_for_training(self, args):
|
439 |
+
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
|
440 |
+
|
441 |
+
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
|
442 |
+
if index == 0: # CLIP-L
|
443 |
+
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
|
444 |
+
else: # T5XXL
|
445 |
+
text_encoder.encoder.embed_tokens.requires_grad_(True)
|
446 |
+
|
447 |
+
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
|
448 |
+
if index == 0: # CLIP-L
|
449 |
+
logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
|
450 |
+
text_encoder.to(te_weight_dtype) # fp8
|
451 |
+
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
|
452 |
+
else: # T5XXL
|
453 |
+
|
454 |
+
def prepare_fp8(text_encoder, target_dtype):
|
455 |
+
def forward_hook(module):
|
456 |
+
def forward(hidden_states):
|
457 |
+
hidden_gelu = module.act(module.wi_0(hidden_states))
|
458 |
+
hidden_linear = module.wi_1(hidden_states)
|
459 |
+
hidden_states = hidden_gelu * hidden_linear
|
460 |
+
hidden_states = module.dropout(hidden_states)
|
461 |
+
|
462 |
+
hidden_states = module.wo(hidden_states)
|
463 |
+
return hidden_states
|
464 |
+
|
465 |
+
return forward
|
466 |
+
|
467 |
+
for module in text_encoder.modules():
|
468 |
+
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
|
469 |
+
# print("set", module.__class__.__name__, "to", target_dtype)
|
470 |
+
module.to(target_dtype)
|
471 |
+
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
|
472 |
+
# print("set", module.__class__.__name__, "hooks")
|
473 |
+
module.forward = forward_hook(module)
|
474 |
+
|
475 |
+
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
|
476 |
+
logger.info(f"T5XXL already prepared for fp8")
|
477 |
+
else:
|
478 |
+
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
|
479 |
+
text_encoder.to(te_weight_dtype) # fp8
|
480 |
+
prepare_fp8(text_encoder, weight_dtype)
|
481 |
+
|
482 |
+
def prepare_unet_with_accelerator(
|
483 |
+
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
484 |
+
) -> torch.nn.Module:
|
485 |
+
if not self.is_swapping_blocks:
|
486 |
+
return super().prepare_unet_with_accelerator(args, accelerator, unet)
|
487 |
+
|
488 |
+
# if we doesn't swap blocks, we can move the model to device
|
489 |
+
flux: flux_models.Flux = unet
|
490 |
+
flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
|
491 |
+
accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
|
492 |
+
accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
|
493 |
+
|
494 |
+
return flux
|
495 |
+
|
496 |
+
|
497 |
+
def setup_parser() -> argparse.ArgumentParser:
|
498 |
+
parser = setup_parser()
|
499 |
+
train_util.add_dit_training_arguments(parser)
|
500 |
+
flux_train_utils.add_flux_train_arguments(parser)
|
hf_token.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"hf_token": "your_token_here"
|
3 |
+
}
|
icon.png
ADDED
![]() |
install.js
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
module.exports = {
|
2 |
+
run: [
|
3 |
+
{
|
4 |
+
method: "shell.run",
|
5 |
+
params: {
|
6 |
+
venv: "env",
|
7 |
+
message: [
|
8 |
+
"git config --global --add safe.directory '*'",
|
9 |
+
"git clone -b sd3 https://github.com/kohya-ss/sd-scripts"
|
10 |
+
]
|
11 |
+
}
|
12 |
+
},
|
13 |
+
{
|
14 |
+
method: "shell.run",
|
15 |
+
params: {
|
16 |
+
path: "sd-scripts",
|
17 |
+
venv: "../env",
|
18 |
+
message: [
|
19 |
+
"uv pip install -r requirements.txt",
|
20 |
+
]
|
21 |
+
}
|
22 |
+
},
|
23 |
+
{
|
24 |
+
method: "shell.run",
|
25 |
+
params: {
|
26 |
+
venv: "env",
|
27 |
+
message: [
|
28 |
+
"pip uninstall -y diffusers[torch] torch torchaudio torchvision",
|
29 |
+
"uv pip install -r requirements.txt",
|
30 |
+
]
|
31 |
+
}
|
32 |
+
},
|
33 |
+
{
|
34 |
+
method: "script.start",
|
35 |
+
params: {
|
36 |
+
uri: "torch.js",
|
37 |
+
params: {
|
38 |
+
venv: "env",
|
39 |
+
// xformers: true // uncomment this line if your project requires xformers
|
40 |
+
}
|
41 |
+
}
|
42 |
+
},
|
43 |
+
{
|
44 |
+
method: "fs.link",
|
45 |
+
params: {
|
46 |
+
drive: {
|
47 |
+
vae: "models/vae",
|
48 |
+
clip: "models/clip",
|
49 |
+
unet: "models/unet",
|
50 |
+
loras: "outputs",
|
51 |
+
},
|
52 |
+
peers: [
|
53 |
+
"https://github.com/pinokiofactory/stable-diffusion-webui-forge.git",
|
54 |
+
"https://github.com/pinokiofactory/comfy.git",
|
55 |
+
"https://github.com/cocktailpeanutlabs/comfyui.git",
|
56 |
+
"https://github.com/cocktailpeanutlabs/fooocus.git",
|
57 |
+
"https://github.com/cocktailpeanutlabs/automatic1111.git",
|
58 |
+
]
|
59 |
+
}
|
60 |
+
},
|
61 |
+
// {
|
62 |
+
// method: "fs.download",
|
63 |
+
// params: {
|
64 |
+
// uri: [
|
65 |
+
// "https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/clip_l.safetensors?download=true",
|
66 |
+
// "https://huggingface.co/comfyanonymous/flux_text_encoders/resolve/main/t5xxl_fp16.safetensors?download=true",
|
67 |
+
// ],
|
68 |
+
// dir: "models/clip"
|
69 |
+
// }
|
70 |
+
// },
|
71 |
+
// {
|
72 |
+
// method: "fs.download",
|
73 |
+
// params: {
|
74 |
+
// uri: [
|
75 |
+
// "https://huggingface.co/cocktailpeanut/xulf-dev/resolve/main/ae.sft?download=true",
|
76 |
+
// ],
|
77 |
+
// dir: "models/vae"
|
78 |
+
// }
|
79 |
+
// },
|
80 |
+
// {
|
81 |
+
// method: "fs.download",
|
82 |
+
// params: {
|
83 |
+
// uri: [
|
84 |
+
// "https://huggingface.co/cocktailpeanut/xulf-dev/resolve/main/flux1-dev.sft?download=true",
|
85 |
+
// ],
|
86 |
+
// dir: "models/unet"
|
87 |
+
// }
|
88 |
+
// },
|
89 |
+
{
|
90 |
+
method: "fs.link",
|
91 |
+
params: {
|
92 |
+
venv: "env"
|
93 |
+
}
|
94 |
+
}
|
95 |
+
]
|
96 |
+
}
|
library/__init__.py
ADDED
File without changes
|
library/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (163 Bytes). View file
|
|
library/__pycache__/config_util.cpython-310.pyc
ADDED
Binary file (20.2 kB). View file
|
|
library/__pycache__/custom_offloading_utils.cpython-310.pyc
ADDED
Binary file (6.98 kB). View file
|
|
library/__pycache__/custom_train_functions.cpython-310.pyc
ADDED
Binary file (13.5 kB). View file
|
|
library/__pycache__/deepspeed_utils.cpython-310.pyc
ADDED
Binary file (4.79 kB). View file
|
|
library/__pycache__/device_utils.cpython-310.pyc
ADDED
Binary file (2.07 kB). View file
|
|
library/__pycache__/flux_models.cpython-310.pyc
ADDED
Binary file (30.6 kB). View file
|
|
library/__pycache__/flux_train_utils.cpython-310.pyc
ADDED
Binary file (14.8 kB). View file
|
|
library/__pycache__/flux_utils.cpython-310.pyc
ADDED
Binary file (16.5 kB). View file
|
|
library/__pycache__/huggingface_util.cpython-310.pyc
ADDED
Binary file (2.79 kB). View file
|
|
library/__pycache__/model_util.cpython-310.pyc
ADDED
Binary file (32.8 kB). View file
|
|
library/__pycache__/original_unet.cpython-310.pyc
ADDED
Binary file (44.1 kB). View file
|
|
library/__pycache__/sai_model_spec.cpython-310.pyc
ADDED
Binary file (5.68 kB). View file
|
|
library/__pycache__/sd3_models.cpython-310.pyc
ADDED
Binary file (38.8 kB). View file
|
|
library/__pycache__/sd3_utils.cpython-310.pyc
ADDED
Binary file (8.48 kB). View file
|
|
library/__pycache__/strategy_base.cpython-310.pyc
ADDED
Binary file (17.4 kB). View file
|
|
library/__pycache__/strategy_sd.cpython-310.pyc
ADDED
Binary file (6.7 kB). View file
|
|
library/__pycache__/train_util.cpython-310.pyc
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fa71a44895d0a006e41ba9fadbd0177a9ad5499cc89aeb2266aa1c7a9597e82e
|
3 |
+
size 164434
|
library/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (15.9 kB). View file
|
|
library/adafactor_fused.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from transformers import Adafactor
|
4 |
+
|
5 |
+
# stochastic rounding for bfloat16
|
6 |
+
# The implementation was provided by 2kpr. Thank you very much!
|
7 |
+
|
8 |
+
def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
|
9 |
+
"""
|
10 |
+
copies source into target using stochastic rounding
|
11 |
+
|
12 |
+
Args:
|
13 |
+
target: the target tensor with dtype=bfloat16
|
14 |
+
source: the target tensor with dtype=float32
|
15 |
+
"""
|
16 |
+
# create a random 16 bit integer
|
17 |
+
result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
|
18 |
+
|
19 |
+
# add the random number to the lower 16 bit of the mantissa
|
20 |
+
result.add_(source.view(dtype=torch.int32))
|
21 |
+
|
22 |
+
# mask off the lower 16 bit of the mantissa
|
23 |
+
result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
|
24 |
+
|
25 |
+
# copy the higher 16 bit into the target tensor
|
26 |
+
target.copy_(result.view(dtype=torch.float32))
|
27 |
+
|
28 |
+
del result
|
29 |
+
|
30 |
+
|
31 |
+
@torch.no_grad()
|
32 |
+
def adafactor_step_param(self, p, group):
|
33 |
+
if p.grad is None:
|
34 |
+
return
|
35 |
+
grad = p.grad
|
36 |
+
if grad.dtype in {torch.float16, torch.bfloat16}:
|
37 |
+
grad = grad.float()
|
38 |
+
if grad.is_sparse:
|
39 |
+
raise RuntimeError("Adafactor does not support sparse gradients.")
|
40 |
+
|
41 |
+
state = self.state[p]
|
42 |
+
grad_shape = grad.shape
|
43 |
+
|
44 |
+
factored, use_first_moment = Adafactor._get_options(group, grad_shape)
|
45 |
+
# State Initialization
|
46 |
+
if len(state) == 0:
|
47 |
+
state["step"] = 0
|
48 |
+
|
49 |
+
if use_first_moment:
|
50 |
+
# Exponential moving average of gradient values
|
51 |
+
state["exp_avg"] = torch.zeros_like(grad)
|
52 |
+
if factored:
|
53 |
+
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
|
54 |
+
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
|
55 |
+
else:
|
56 |
+
state["exp_avg_sq"] = torch.zeros_like(grad)
|
57 |
+
|
58 |
+
state["RMS"] = 0
|
59 |
+
else:
|
60 |
+
if use_first_moment:
|
61 |
+
state["exp_avg"] = state["exp_avg"].to(grad)
|
62 |
+
if factored:
|
63 |
+
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
|
64 |
+
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
|
65 |
+
else:
|
66 |
+
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
|
67 |
+
|
68 |
+
p_data_fp32 = p
|
69 |
+
if p.dtype in {torch.float16, torch.bfloat16}:
|
70 |
+
p_data_fp32 = p_data_fp32.float()
|
71 |
+
|
72 |
+
state["step"] += 1
|
73 |
+
state["RMS"] = Adafactor._rms(p_data_fp32)
|
74 |
+
lr = Adafactor._get_lr(group, state)
|
75 |
+
|
76 |
+
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
|
77 |
+
update = (grad**2) + group["eps"][0]
|
78 |
+
if factored:
|
79 |
+
exp_avg_sq_row = state["exp_avg_sq_row"]
|
80 |
+
exp_avg_sq_col = state["exp_avg_sq_col"]
|
81 |
+
|
82 |
+
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
|
83 |
+
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
|
84 |
+
|
85 |
+
# Approximation of exponential moving average of square of gradient
|
86 |
+
update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
|
87 |
+
update.mul_(grad)
|
88 |
+
else:
|
89 |
+
exp_avg_sq = state["exp_avg_sq"]
|
90 |
+
|
91 |
+
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
|
92 |
+
update = exp_avg_sq.rsqrt().mul_(grad)
|
93 |
+
|
94 |
+
update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
|
95 |
+
update.mul_(lr)
|
96 |
+
|
97 |
+
if use_first_moment:
|
98 |
+
exp_avg = state["exp_avg"]
|
99 |
+
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
|
100 |
+
update = exp_avg
|
101 |
+
|
102 |
+
if group["weight_decay"] != 0:
|
103 |
+
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
|
104 |
+
|
105 |
+
p_data_fp32.add_(-update)
|
106 |
+
|
107 |
+
# if p.dtype in {torch.float16, torch.bfloat16}:
|
108 |
+
# p.copy_(p_data_fp32)
|
109 |
+
|
110 |
+
if p.dtype == torch.bfloat16:
|
111 |
+
copy_stochastic_(p, p_data_fp32)
|
112 |
+
elif p.dtype == torch.float16:
|
113 |
+
p.copy_(p_data_fp32)
|
114 |
+
|
115 |
+
|
116 |
+
@torch.no_grad()
|
117 |
+
def adafactor_step(self, closure=None):
|
118 |
+
"""
|
119 |
+
Performs a single optimization step
|
120 |
+
|
121 |
+
Arguments:
|
122 |
+
closure (callable, optional): A closure that reevaluates the model
|
123 |
+
and returns the loss.
|
124 |
+
"""
|
125 |
+
loss = None
|
126 |
+
if closure is not None:
|
127 |
+
loss = closure()
|
128 |
+
|
129 |
+
for group in self.param_groups:
|
130 |
+
for p in group["params"]:
|
131 |
+
adafactor_step_param(self, p, group)
|
132 |
+
|
133 |
+
return loss
|
134 |
+
|
135 |
+
|
136 |
+
def patch_adafactor_fused(optimizer: Adafactor):
|
137 |
+
optimizer.step_param = adafactor_step_param.__get__(optimizer)
|
138 |
+
optimizer.step = adafactor_step.__get__(optimizer)
|
library/attention_processors.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Any
|
3 |
+
from einops import rearrange
|
4 |
+
import torch
|
5 |
+
from diffusers.models.attention_processor import Attention
|
6 |
+
|
7 |
+
|
8 |
+
# flash attention forwards and backwards
|
9 |
+
|
10 |
+
# https://arxiv.org/abs/2205.14135
|
11 |
+
|
12 |
+
EPSILON = 1e-6
|
13 |
+
|
14 |
+
|
15 |
+
class FlashAttentionFunction(torch.autograd.function.Function):
|
16 |
+
@staticmethod
|
17 |
+
@torch.no_grad()
|
18 |
+
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
|
19 |
+
"""Algorithm 2 in the paper"""
|
20 |
+
|
21 |
+
device = q.device
|
22 |
+
dtype = q.dtype
|
23 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
24 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
25 |
+
|
26 |
+
o = torch.zeros_like(q)
|
27 |
+
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
|
28 |
+
all_row_maxes = torch.full(
|
29 |
+
(*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
|
30 |
+
)
|
31 |
+
|
32 |
+
scale = q.shape[-1] ** -0.5
|
33 |
+
|
34 |
+
if mask is None:
|
35 |
+
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
|
36 |
+
else:
|
37 |
+
mask = rearrange(mask, "b n -> b 1 1 n")
|
38 |
+
mask = mask.split(q_bucket_size, dim=-1)
|
39 |
+
|
40 |
+
row_splits = zip(
|
41 |
+
q.split(q_bucket_size, dim=-2),
|
42 |
+
o.split(q_bucket_size, dim=-2),
|
43 |
+
mask,
|
44 |
+
all_row_sums.split(q_bucket_size, dim=-2),
|
45 |
+
all_row_maxes.split(q_bucket_size, dim=-2),
|
46 |
+
)
|
47 |
+
|
48 |
+
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
|
49 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
50 |
+
|
51 |
+
col_splits = zip(
|
52 |
+
k.split(k_bucket_size, dim=-2),
|
53 |
+
v.split(k_bucket_size, dim=-2),
|
54 |
+
)
|
55 |
+
|
56 |
+
for k_ind, (kc, vc) in enumerate(col_splits):
|
57 |
+
k_start_index = k_ind * k_bucket_size
|
58 |
+
|
59 |
+
attn_weights = (
|
60 |
+
torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
61 |
+
)
|
62 |
+
|
63 |
+
if row_mask is not None:
|
64 |
+
attn_weights.masked_fill_(~row_mask, max_neg_value)
|
65 |
+
|
66 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
67 |
+
causal_mask = torch.ones(
|
68 |
+
(qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
|
69 |
+
).triu(q_start_index - k_start_index + 1)
|
70 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
71 |
+
|
72 |
+
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
|
73 |
+
attn_weights -= block_row_maxes
|
74 |
+
exp_weights = torch.exp(attn_weights)
|
75 |
+
|
76 |
+
if row_mask is not None:
|
77 |
+
exp_weights.masked_fill_(~row_mask, 0.0)
|
78 |
+
|
79 |
+
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
|
80 |
+
min=EPSILON
|
81 |
+
)
|
82 |
+
|
83 |
+
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
|
84 |
+
|
85 |
+
exp_values = torch.einsum(
|
86 |
+
"... i j, ... j d -> ... i d", exp_weights, vc
|
87 |
+
)
|
88 |
+
|
89 |
+
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
|
90 |
+
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
|
91 |
+
|
92 |
+
new_row_sums = (
|
93 |
+
exp_row_max_diff * row_sums
|
94 |
+
+ exp_block_row_max_diff * block_row_sums
|
95 |
+
)
|
96 |
+
|
97 |
+
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
|
98 |
+
(exp_block_row_max_diff / new_row_sums) * exp_values
|
99 |
+
)
|
100 |
+
|
101 |
+
row_maxes.copy_(new_row_maxes)
|
102 |
+
row_sums.copy_(new_row_sums)
|
103 |
+
|
104 |
+
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
|
105 |
+
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
|
106 |
+
|
107 |
+
return o
|
108 |
+
|
109 |
+
@staticmethod
|
110 |
+
@torch.no_grad()
|
111 |
+
def backward(ctx, do):
|
112 |
+
"""Algorithm 4 in the paper"""
|
113 |
+
|
114 |
+
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
|
115 |
+
q, k, v, o, l, m = ctx.saved_tensors
|
116 |
+
|
117 |
+
device = q.device
|
118 |
+
|
119 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
120 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
121 |
+
|
122 |
+
dq = torch.zeros_like(q)
|
123 |
+
dk = torch.zeros_like(k)
|
124 |
+
dv = torch.zeros_like(v)
|
125 |
+
|
126 |
+
row_splits = zip(
|
127 |
+
q.split(q_bucket_size, dim=-2),
|
128 |
+
o.split(q_bucket_size, dim=-2),
|
129 |
+
do.split(q_bucket_size, dim=-2),
|
130 |
+
mask,
|
131 |
+
l.split(q_bucket_size, dim=-2),
|
132 |
+
m.split(q_bucket_size, dim=-2),
|
133 |
+
dq.split(q_bucket_size, dim=-2),
|
134 |
+
)
|
135 |
+
|
136 |
+
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
|
137 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
138 |
+
|
139 |
+
col_splits = zip(
|
140 |
+
k.split(k_bucket_size, dim=-2),
|
141 |
+
v.split(k_bucket_size, dim=-2),
|
142 |
+
dk.split(k_bucket_size, dim=-2),
|
143 |
+
dv.split(k_bucket_size, dim=-2),
|
144 |
+
)
|
145 |
+
|
146 |
+
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
|
147 |
+
k_start_index = k_ind * k_bucket_size
|
148 |
+
|
149 |
+
attn_weights = (
|
150 |
+
torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
151 |
+
)
|
152 |
+
|
153 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
154 |
+
causal_mask = torch.ones(
|
155 |
+
(qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
|
156 |
+
).triu(q_start_index - k_start_index + 1)
|
157 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
158 |
+
|
159 |
+
exp_attn_weights = torch.exp(attn_weights - mc)
|
160 |
+
|
161 |
+
if row_mask is not None:
|
162 |
+
exp_attn_weights.masked_fill_(~row_mask, 0.0)
|
163 |
+
|
164 |
+
p = exp_attn_weights / lc
|
165 |
+
|
166 |
+
dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
|
167 |
+
dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
|
168 |
+
|
169 |
+
D = (doc * oc).sum(dim=-1, keepdims=True)
|
170 |
+
ds = p * scale * (dp - D)
|
171 |
+
|
172 |
+
dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
|
173 |
+
dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
|
174 |
+
|
175 |
+
dqc.add_(dq_chunk)
|
176 |
+
dkc.add_(dk_chunk)
|
177 |
+
dvc.add_(dv_chunk)
|
178 |
+
|
179 |
+
return dq, dk, dv, None, None, None, None
|
180 |
+
|
181 |
+
|
182 |
+
class FlashAttnProcessor:
|
183 |
+
def __call__(
|
184 |
+
self,
|
185 |
+
attn: Attention,
|
186 |
+
hidden_states,
|
187 |
+
encoder_hidden_states=None,
|
188 |
+
attention_mask=None,
|
189 |
+
) -> Any:
|
190 |
+
q_bucket_size = 512
|
191 |
+
k_bucket_size = 1024
|
192 |
+
|
193 |
+
h = attn.heads
|
194 |
+
q = attn.to_q(hidden_states)
|
195 |
+
|
196 |
+
encoder_hidden_states = (
|
197 |
+
encoder_hidden_states
|
198 |
+
if encoder_hidden_states is not None
|
199 |
+
else hidden_states
|
200 |
+
)
|
201 |
+
encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
|
202 |
+
|
203 |
+
if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
|
204 |
+
context_k, context_v = attn.hypernetwork.forward(
|
205 |
+
hidden_states, encoder_hidden_states
|
206 |
+
)
|
207 |
+
context_k = context_k.to(hidden_states.dtype)
|
208 |
+
context_v = context_v.to(hidden_states.dtype)
|
209 |
+
else:
|
210 |
+
context_k = encoder_hidden_states
|
211 |
+
context_v = encoder_hidden_states
|
212 |
+
|
213 |
+
k = attn.to_k(context_k)
|
214 |
+
v = attn.to_v(context_v)
|
215 |
+
del encoder_hidden_states, hidden_states
|
216 |
+
|
217 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
218 |
+
|
219 |
+
out = FlashAttentionFunction.apply(
|
220 |
+
q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
|
221 |
+
)
|
222 |
+
|
223 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
224 |
+
|
225 |
+
out = attn.to_out[0](out)
|
226 |
+
out = attn.to_out[1](out)
|
227 |
+
return out
|
library/config_util.py
ADDED
@@ -0,0 +1,717 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from dataclasses import (
|
3 |
+
asdict,
|
4 |
+
dataclass,
|
5 |
+
)
|
6 |
+
import functools
|
7 |
+
import random
|
8 |
+
from textwrap import dedent, indent
|
9 |
+
import json
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
# from toolz import curry
|
13 |
+
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
14 |
+
|
15 |
+
import toml
|
16 |
+
import voluptuous
|
17 |
+
from voluptuous import (
|
18 |
+
Any,
|
19 |
+
ExactSequence,
|
20 |
+
MultipleInvalid,
|
21 |
+
Object,
|
22 |
+
Required,
|
23 |
+
Schema,
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
from . import train_util
|
28 |
+
from .train_util import (
|
29 |
+
DreamBoothSubset,
|
30 |
+
FineTuningSubset,
|
31 |
+
ControlNetSubset,
|
32 |
+
DreamBoothDataset,
|
33 |
+
FineTuningDataset,
|
34 |
+
ControlNetDataset,
|
35 |
+
DatasetGroup,
|
36 |
+
)
|
37 |
+
from .utils import setup_logging
|
38 |
+
|
39 |
+
setup_logging()
|
40 |
+
import logging
|
41 |
+
|
42 |
+
logger = logging.getLogger(__name__)
|
43 |
+
|
44 |
+
|
45 |
+
def add_config_arguments(parser: argparse.ArgumentParser):
|
46 |
+
parser.add_argument(
|
47 |
+
"--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
# TODO: inherit Params class in Subset, Dataset
|
52 |
+
|
53 |
+
|
54 |
+
@dataclass
|
55 |
+
class BaseSubsetParams:
|
56 |
+
image_dir: Optional[str] = None
|
57 |
+
num_repeats: int = 1
|
58 |
+
shuffle_caption: bool = False
|
59 |
+
caption_separator: str = (",",)
|
60 |
+
keep_tokens: int = 0
|
61 |
+
keep_tokens_separator: str = (None,)
|
62 |
+
secondary_separator: Optional[str] = None
|
63 |
+
enable_wildcard: bool = False
|
64 |
+
color_aug: bool = False
|
65 |
+
flip_aug: bool = False
|
66 |
+
face_crop_aug_range: Optional[Tuple[float, float]] = None
|
67 |
+
random_crop: bool = False
|
68 |
+
caption_prefix: Optional[str] = None
|
69 |
+
caption_suffix: Optional[str] = None
|
70 |
+
caption_dropout_rate: float = 0.0
|
71 |
+
caption_dropout_every_n_epochs: int = 0
|
72 |
+
caption_tag_dropout_rate: float = 0.0
|
73 |
+
token_warmup_min: int = 1
|
74 |
+
token_warmup_step: float = 0
|
75 |
+
custom_attributes: Optional[Dict[str, Any]] = None
|
76 |
+
|
77 |
+
|
78 |
+
@dataclass
|
79 |
+
class DreamBoothSubsetParams(BaseSubsetParams):
|
80 |
+
is_reg: bool = False
|
81 |
+
class_tokens: Optional[str] = None
|
82 |
+
caption_extension: str = ".caption"
|
83 |
+
cache_info: bool = False
|
84 |
+
alpha_mask: bool = False
|
85 |
+
|
86 |
+
|
87 |
+
@dataclass
|
88 |
+
class FineTuningSubsetParams(BaseSubsetParams):
|
89 |
+
metadata_file: Optional[str] = None
|
90 |
+
alpha_mask: bool = False
|
91 |
+
|
92 |
+
|
93 |
+
@dataclass
|
94 |
+
class ControlNetSubsetParams(BaseSubsetParams):
|
95 |
+
conditioning_data_dir: str = None
|
96 |
+
caption_extension: str = ".caption"
|
97 |
+
cache_info: bool = False
|
98 |
+
|
99 |
+
|
100 |
+
@dataclass
|
101 |
+
class BaseDatasetParams:
|
102 |
+
resolution: Optional[Tuple[int, int]] = None
|
103 |
+
network_multiplier: float = 1.0
|
104 |
+
debug_dataset: bool = False
|
105 |
+
|
106 |
+
|
107 |
+
@dataclass
|
108 |
+
class DreamBoothDatasetParams(BaseDatasetParams):
|
109 |
+
batch_size: int = 1
|
110 |
+
enable_bucket: bool = False
|
111 |
+
min_bucket_reso: int = 256
|
112 |
+
max_bucket_reso: int = 1024
|
113 |
+
bucket_reso_steps: int = 64
|
114 |
+
bucket_no_upscale: bool = False
|
115 |
+
prior_loss_weight: float = 1.0
|
116 |
+
|
117 |
+
|
118 |
+
@dataclass
|
119 |
+
class FineTuningDatasetParams(BaseDatasetParams):
|
120 |
+
batch_size: int = 1
|
121 |
+
enable_bucket: bool = False
|
122 |
+
min_bucket_reso: int = 256
|
123 |
+
max_bucket_reso: int = 1024
|
124 |
+
bucket_reso_steps: int = 64
|
125 |
+
bucket_no_upscale: bool = False
|
126 |
+
|
127 |
+
|
128 |
+
@dataclass
|
129 |
+
class ControlNetDatasetParams(BaseDatasetParams):
|
130 |
+
batch_size: int = 1
|
131 |
+
enable_bucket: bool = False
|
132 |
+
min_bucket_reso: int = 256
|
133 |
+
max_bucket_reso: int = 1024
|
134 |
+
bucket_reso_steps: int = 64
|
135 |
+
bucket_no_upscale: bool = False
|
136 |
+
|
137 |
+
|
138 |
+
@dataclass
|
139 |
+
class SubsetBlueprint:
|
140 |
+
params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
|
141 |
+
|
142 |
+
|
143 |
+
@dataclass
|
144 |
+
class DatasetBlueprint:
|
145 |
+
is_dreambooth: bool
|
146 |
+
is_controlnet: bool
|
147 |
+
params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
|
148 |
+
subsets: Sequence[SubsetBlueprint]
|
149 |
+
|
150 |
+
|
151 |
+
@dataclass
|
152 |
+
class DatasetGroupBlueprint:
|
153 |
+
datasets: Sequence[DatasetBlueprint]
|
154 |
+
|
155 |
+
|
156 |
+
@dataclass
|
157 |
+
class Blueprint:
|
158 |
+
dataset_group: DatasetGroupBlueprint
|
159 |
+
|
160 |
+
|
161 |
+
class ConfigSanitizer:
|
162 |
+
# @curry
|
163 |
+
@staticmethod
|
164 |
+
def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
|
165 |
+
Schema(ExactSequence([klass, klass]))(value)
|
166 |
+
return tuple(value)
|
167 |
+
|
168 |
+
# @curry
|
169 |
+
@staticmethod
|
170 |
+
def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
|
171 |
+
Schema(Any(klass, ExactSequence([klass, klass])))(value)
|
172 |
+
try:
|
173 |
+
Schema(klass)(value)
|
174 |
+
return (value, value)
|
175 |
+
except:
|
176 |
+
return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
|
177 |
+
|
178 |
+
# subset schema
|
179 |
+
SUBSET_ASCENDABLE_SCHEMA = {
|
180 |
+
"color_aug": bool,
|
181 |
+
"face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
|
182 |
+
"flip_aug": bool,
|
183 |
+
"num_repeats": int,
|
184 |
+
"random_crop": bool,
|
185 |
+
"shuffle_caption": bool,
|
186 |
+
"keep_tokens": int,
|
187 |
+
"keep_tokens_separator": str,
|
188 |
+
"secondary_separator": str,
|
189 |
+
"caption_separator": str,
|
190 |
+
"enable_wildcard": bool,
|
191 |
+
"token_warmup_min": int,
|
192 |
+
"token_warmup_step": Any(float, int),
|
193 |
+
"caption_prefix": str,
|
194 |
+
"caption_suffix": str,
|
195 |
+
"custom_attributes": dict,
|
196 |
+
}
|
197 |
+
# DO means DropOut
|
198 |
+
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
199 |
+
"caption_dropout_every_n_epochs": int,
|
200 |
+
"caption_dropout_rate": Any(float, int),
|
201 |
+
"caption_tag_dropout_rate": Any(float, int),
|
202 |
+
}
|
203 |
+
# DB means DreamBooth
|
204 |
+
DB_SUBSET_ASCENDABLE_SCHEMA = {
|
205 |
+
"caption_extension": str,
|
206 |
+
"class_tokens": str,
|
207 |
+
"cache_info": bool,
|
208 |
+
}
|
209 |
+
DB_SUBSET_DISTINCT_SCHEMA = {
|
210 |
+
Required("image_dir"): str,
|
211 |
+
"is_reg": bool,
|
212 |
+
"alpha_mask": bool,
|
213 |
+
}
|
214 |
+
# FT means FineTuning
|
215 |
+
FT_SUBSET_DISTINCT_SCHEMA = {
|
216 |
+
Required("metadata_file"): str,
|
217 |
+
"image_dir": str,
|
218 |
+
"alpha_mask": bool,
|
219 |
+
}
|
220 |
+
CN_SUBSET_ASCENDABLE_SCHEMA = {
|
221 |
+
"caption_extension": str,
|
222 |
+
"cache_info": bool,
|
223 |
+
}
|
224 |
+
CN_SUBSET_DISTINCT_SCHEMA = {
|
225 |
+
Required("image_dir"): str,
|
226 |
+
Required("conditioning_data_dir"): str,
|
227 |
+
}
|
228 |
+
|
229 |
+
# datasets schema
|
230 |
+
DATASET_ASCENDABLE_SCHEMA = {
|
231 |
+
"batch_size": int,
|
232 |
+
"bucket_no_upscale": bool,
|
233 |
+
"bucket_reso_steps": int,
|
234 |
+
"enable_bucket": bool,
|
235 |
+
"max_bucket_reso": int,
|
236 |
+
"min_bucket_reso": int,
|
237 |
+
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
238 |
+
"network_multiplier": float,
|
239 |
+
}
|
240 |
+
|
241 |
+
# options handled by argparse but not handled by user config
|
242 |
+
ARGPARSE_SPECIFIC_SCHEMA = {
|
243 |
+
"debug_dataset": bool,
|
244 |
+
"max_token_length": Any(None, int),
|
245 |
+
"prior_loss_weight": Any(float, int),
|
246 |
+
}
|
247 |
+
# for handling default None value of argparse
|
248 |
+
ARGPARSE_NULLABLE_OPTNAMES = [
|
249 |
+
"face_crop_aug_range",
|
250 |
+
"resolution",
|
251 |
+
]
|
252 |
+
# prepare map because option name may differ among argparse and user config
|
253 |
+
ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
|
254 |
+
"train_batch_size": "batch_size",
|
255 |
+
"dataset_repeats": "num_repeats",
|
256 |
+
}
|
257 |
+
|
258 |
+
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
|
259 |
+
assert support_dreambooth or support_finetuning or support_controlnet, (
|
260 |
+
"Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
|
261 |
+
+ " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
|
262 |
+
)
|
263 |
+
|
264 |
+
self.db_subset_schema = self.__merge_dict(
|
265 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
266 |
+
self.DB_SUBSET_DISTINCT_SCHEMA,
|
267 |
+
self.DB_SUBSET_ASCENDABLE_SCHEMA,
|
268 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
269 |
+
)
|
270 |
+
|
271 |
+
self.ft_subset_schema = self.__merge_dict(
|
272 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
273 |
+
self.FT_SUBSET_DISTINCT_SCHEMA,
|
274 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
275 |
+
)
|
276 |
+
|
277 |
+
self.cn_subset_schema = self.__merge_dict(
|
278 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
279 |
+
self.CN_SUBSET_DISTINCT_SCHEMA,
|
280 |
+
self.CN_SUBSET_ASCENDABLE_SCHEMA,
|
281 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
282 |
+
)
|
283 |
+
|
284 |
+
self.db_dataset_schema = self.__merge_dict(
|
285 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
286 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
287 |
+
self.DB_SUBSET_ASCENDABLE_SCHEMA,
|
288 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
289 |
+
{"subsets": [self.db_subset_schema]},
|
290 |
+
)
|
291 |
+
|
292 |
+
self.ft_dataset_schema = self.__merge_dict(
|
293 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
294 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
295 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
296 |
+
{"subsets": [self.ft_subset_schema]},
|
297 |
+
)
|
298 |
+
|
299 |
+
self.cn_dataset_schema = self.__merge_dict(
|
300 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
301 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
302 |
+
self.CN_SUBSET_ASCENDABLE_SCHEMA,
|
303 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
304 |
+
{"subsets": [self.cn_subset_schema]},
|
305 |
+
)
|
306 |
+
|
307 |
+
if support_dreambooth and support_finetuning:
|
308 |
+
|
309 |
+
def validate_flex_dataset(dataset_config: dict):
|
310 |
+
subsets_config = dataset_config.get("subsets", [])
|
311 |
+
|
312 |
+
if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]):
|
313 |
+
return Schema(self.cn_dataset_schema)(dataset_config)
|
314 |
+
# check dataset meets FT style
|
315 |
+
# NOTE: all FT subsets should have "metadata_file"
|
316 |
+
elif all(["metadata_file" in subset for subset in subsets_config]):
|
317 |
+
return Schema(self.ft_dataset_schema)(dataset_config)
|
318 |
+
# check dataset meets DB style
|
319 |
+
# NOTE: all DB subsets should have no "metadata_file"
|
320 |
+
elif all(["metadata_file" not in subset for subset in subsets_config]):
|
321 |
+
return Schema(self.db_dataset_schema)(dataset_config)
|
322 |
+
else:
|
323 |
+
raise voluptuous.Invalid(
|
324 |
+
"DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。"
|
325 |
+
)
|
326 |
+
|
327 |
+
self.dataset_schema = validate_flex_dataset
|
328 |
+
elif support_dreambooth:
|
329 |
+
if support_controlnet:
|
330 |
+
self.dataset_schema = self.cn_dataset_schema
|
331 |
+
else:
|
332 |
+
self.dataset_schema = self.db_dataset_schema
|
333 |
+
elif support_finetuning:
|
334 |
+
self.dataset_schema = self.ft_dataset_schema
|
335 |
+
elif support_controlnet:
|
336 |
+
self.dataset_schema = self.cn_dataset_schema
|
337 |
+
|
338 |
+
self.general_schema = self.__merge_dict(
|
339 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
340 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
341 |
+
self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
|
342 |
+
self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
|
343 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
344 |
+
)
|
345 |
+
|
346 |
+
self.user_config_validator = Schema(
|
347 |
+
{
|
348 |
+
"general": self.general_schema,
|
349 |
+
"datasets": [self.dataset_schema],
|
350 |
+
}
|
351 |
+
)
|
352 |
+
|
353 |
+
self.argparse_schema = self.__merge_dict(
|
354 |
+
self.general_schema,
|
355 |
+
self.ARGPARSE_SPECIFIC_SCHEMA,
|
356 |
+
{optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
|
357 |
+
{a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
|
358 |
+
)
|
359 |
+
|
360 |
+
self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
|
361 |
+
|
362 |
+
def sanitize_user_config(self, user_config: dict) -> dict:
|
363 |
+
try:
|
364 |
+
return self.user_config_validator(user_config)
|
365 |
+
except MultipleInvalid:
|
366 |
+
# TODO: エラー発生時のメッセージをわかりやすくする
|
367 |
+
logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
|
368 |
+
raise
|
369 |
+
|
370 |
+
# NOTE: In nature, argument parser result is not needed to be sanitize
|
371 |
+
# However this will help us to detect program bug
|
372 |
+
def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
|
373 |
+
try:
|
374 |
+
return self.argparse_config_validator(argparse_namespace)
|
375 |
+
except MultipleInvalid:
|
376 |
+
# XXX: this should be a bug
|
377 |
+
logger.error(
|
378 |
+
"Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
|
379 |
+
)
|
380 |
+
raise
|
381 |
+
|
382 |
+
# NOTE: value would be overwritten by latter dict if there is already the same key
|
383 |
+
@staticmethod
|
384 |
+
def __merge_dict(*dict_list: dict) -> dict:
|
385 |
+
merged = {}
|
386 |
+
for schema in dict_list:
|
387 |
+
# merged |= schema
|
388 |
+
for k, v in schema.items():
|
389 |
+
merged[k] = v
|
390 |
+
return merged
|
391 |
+
|
392 |
+
|
393 |
+
class BlueprintGenerator:
|
394 |
+
BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
|
395 |
+
|
396 |
+
def __init__(self, sanitizer: ConfigSanitizer):
|
397 |
+
self.sanitizer = sanitizer
|
398 |
+
|
399 |
+
# runtime_params is for parameters which is only configurable on runtime, such as tokenizer
|
400 |
+
def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
|
401 |
+
sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
|
402 |
+
sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
|
403 |
+
|
404 |
+
# convert argparse namespace to dict like config
|
405 |
+
# NOTE: it is ok to have extra entries in dict
|
406 |
+
optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
|
407 |
+
argparse_config = {
|
408 |
+
optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()
|
409 |
+
}
|
410 |
+
|
411 |
+
general_config = sanitized_user_config.get("general", {})
|
412 |
+
|
413 |
+
dataset_blueprints = []
|
414 |
+
for dataset_config in sanitized_user_config.get("datasets", []):
|
415 |
+
# NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
|
416 |
+
subsets = dataset_config.get("subsets", [])
|
417 |
+
is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
|
418 |
+
is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
|
419 |
+
if is_controlnet:
|
420 |
+
subset_params_klass = ControlNetSubsetParams
|
421 |
+
dataset_params_klass = ControlNetDatasetParams
|
422 |
+
elif is_dreambooth:
|
423 |
+
subset_params_klass = DreamBoothSubsetParams
|
424 |
+
dataset_params_klass = DreamBoothDatasetParams
|
425 |
+
else:
|
426 |
+
subset_params_klass = FineTuningSubsetParams
|
427 |
+
dataset_params_klass = FineTuningDatasetParams
|
428 |
+
|
429 |
+
subset_blueprints = []
|
430 |
+
for subset_config in subsets:
|
431 |
+
params = self.generate_params_by_fallbacks(
|
432 |
+
subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params]
|
433 |
+
)
|
434 |
+
subset_blueprints.append(SubsetBlueprint(params))
|
435 |
+
|
436 |
+
params = self.generate_params_by_fallbacks(
|
437 |
+
dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
|
438 |
+
)
|
439 |
+
dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))
|
440 |
+
|
441 |
+
dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
|
442 |
+
|
443 |
+
return Blueprint(dataset_group_blueprint)
|
444 |
+
|
445 |
+
@staticmethod
|
446 |
+
def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
|
447 |
+
name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
|
448 |
+
search_value = BlueprintGenerator.search_value
|
449 |
+
default_params = asdict(param_klass())
|
450 |
+
param_names = default_params.keys()
|
451 |
+
|
452 |
+
params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
|
453 |
+
|
454 |
+
return param_klass(**params)
|
455 |
+
|
456 |
+
@staticmethod
|
457 |
+
def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
|
458 |
+
for cand in fallbacks:
|
459 |
+
value = cand.get(key)
|
460 |
+
if value is not None:
|
461 |
+
return value
|
462 |
+
|
463 |
+
return default_value
|
464 |
+
|
465 |
+
|
466 |
+
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
|
467 |
+
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
|
468 |
+
|
469 |
+
for dataset_blueprint in dataset_group_blueprint.datasets:
|
470 |
+
if dataset_blueprint.is_controlnet:
|
471 |
+
subset_klass = ControlNetSubset
|
472 |
+
dataset_klass = ControlNetDataset
|
473 |
+
elif dataset_blueprint.is_dreambooth:
|
474 |
+
subset_klass = DreamBoothSubset
|
475 |
+
dataset_klass = DreamBoothDataset
|
476 |
+
else:
|
477 |
+
subset_klass = FineTuningSubset
|
478 |
+
dataset_klass = FineTuningDataset
|
479 |
+
|
480 |
+
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
|
481 |
+
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
|
482 |
+
datasets.append(dataset)
|
483 |
+
|
484 |
+
# print info
|
485 |
+
info = ""
|
486 |
+
for i, dataset in enumerate(datasets):
|
487 |
+
is_dreambooth = isinstance(dataset, DreamBoothDataset)
|
488 |
+
is_controlnet = isinstance(dataset, ControlNetDataset)
|
489 |
+
info += dedent(
|
490 |
+
f"""\
|
491 |
+
[Dataset {i}]
|
492 |
+
batch_size: {dataset.batch_size}
|
493 |
+
resolution: {(dataset.width, dataset.height)}
|
494 |
+
enable_bucket: {dataset.enable_bucket}
|
495 |
+
network_multiplier: {dataset.network_multiplier}
|
496 |
+
"""
|
497 |
+
)
|
498 |
+
|
499 |
+
if dataset.enable_bucket:
|
500 |
+
info += indent(
|
501 |
+
dedent(
|
502 |
+
f"""\
|
503 |
+
min_bucket_reso: {dataset.min_bucket_reso}
|
504 |
+
max_bucket_reso: {dataset.max_bucket_reso}
|
505 |
+
bucket_reso_steps: {dataset.bucket_reso_steps}
|
506 |
+
bucket_no_upscale: {dataset.bucket_no_upscale}
|
507 |
+
\n"""
|
508 |
+
),
|
509 |
+
" ",
|
510 |
+
)
|
511 |
+
else:
|
512 |
+
info += "\n"
|
513 |
+
|
514 |
+
for j, subset in enumerate(dataset.subsets):
|
515 |
+
info += indent(
|
516 |
+
dedent(
|
517 |
+
f"""\
|
518 |
+
[Subset {j} of Dataset {i}]
|
519 |
+
image_dir: "{subset.image_dir}"
|
520 |
+
image_count: {subset.img_count}
|
521 |
+
num_repeats: {subset.num_repeats}
|
522 |
+
shuffle_caption: {subset.shuffle_caption}
|
523 |
+
keep_tokens: {subset.keep_tokens}
|
524 |
+
keep_tokens_separator: {subset.keep_tokens_separator}
|
525 |
+
caption_separator: {subset.caption_separator}
|
526 |
+
secondary_separator: {subset.secondary_separator}
|
527 |
+
enable_wildcard: {subset.enable_wildcard}
|
528 |
+
caption_dropout_rate: {subset.caption_dropout_rate}
|
529 |
+
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
|
530 |
+
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
531 |
+
caption_prefix: {subset.caption_prefix}
|
532 |
+
caption_suffix: {subset.caption_suffix}
|
533 |
+
color_aug: {subset.color_aug}
|
534 |
+
flip_aug: {subset.flip_aug}
|
535 |
+
face_crop_aug_range: {subset.face_crop_aug_range}
|
536 |
+
random_crop: {subset.random_crop}
|
537 |
+
token_warmup_min: {subset.token_warmup_min}
|
538 |
+
token_warmup_step: {subset.token_warmup_step}
|
539 |
+
alpha_mask: {subset.alpha_mask}
|
540 |
+
custom_attributes: {subset.custom_attributes}
|
541 |
+
"""
|
542 |
+
),
|
543 |
+
" ",
|
544 |
+
)
|
545 |
+
|
546 |
+
if is_dreambooth:
|
547 |
+
info += indent(
|
548 |
+
dedent(
|
549 |
+
f"""\
|
550 |
+
is_reg: {subset.is_reg}
|
551 |
+
class_tokens: {subset.class_tokens}
|
552 |
+
caption_extension: {subset.caption_extension}
|
553 |
+
\n"""
|
554 |
+
),
|
555 |
+
" ",
|
556 |
+
)
|
557 |
+
elif not is_controlnet:
|
558 |
+
info += indent(
|
559 |
+
dedent(
|
560 |
+
f"""\
|
561 |
+
metadata_file: {subset.metadata_file}
|
562 |
+
\n"""
|
563 |
+
),
|
564 |
+
" ",
|
565 |
+
)
|
566 |
+
|
567 |
+
logger.info(f"{info}")
|
568 |
+
|
569 |
+
# make buckets first because it determines the length of dataset
|
570 |
+
# and set the same seed for all datasets
|
571 |
+
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
572 |
+
for i, dataset in enumerate(datasets):
|
573 |
+
logger.info(f"[Dataset {i}]")
|
574 |
+
dataset.make_buckets()
|
575 |
+
dataset.set_seed(seed)
|
576 |
+
|
577 |
+
return DatasetGroup(datasets)
|
578 |
+
|
579 |
+
|
580 |
+
def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
|
581 |
+
def extract_dreambooth_params(name: str) -> Tuple[int, str]:
|
582 |
+
tokens = name.split("_")
|
583 |
+
try:
|
584 |
+
n_repeats = int(tokens[0])
|
585 |
+
except ValueError as e:
|
586 |
+
logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
|
587 |
+
return 0, ""
|
588 |
+
caption_by_folder = "_".join(tokens[1:])
|
589 |
+
return n_repeats, caption_by_folder
|
590 |
+
|
591 |
+
def generate(base_dir: Optional[str], is_reg: bool):
|
592 |
+
if base_dir is None:
|
593 |
+
return []
|
594 |
+
|
595 |
+
base_dir: Path = Path(base_dir)
|
596 |
+
if not base_dir.is_dir():
|
597 |
+
return []
|
598 |
+
|
599 |
+
subsets_config = []
|
600 |
+
for subdir in base_dir.iterdir():
|
601 |
+
if not subdir.is_dir():
|
602 |
+
continue
|
603 |
+
|
604 |
+
num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
|
605 |
+
if num_repeats < 1:
|
606 |
+
continue
|
607 |
+
|
608 |
+
subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
|
609 |
+
subsets_config.append(subset_config)
|
610 |
+
|
611 |
+
return subsets_config
|
612 |
+
|
613 |
+
subsets_config = []
|
614 |
+
subsets_config += generate(train_data_dir, False)
|
615 |
+
subsets_config += generate(reg_data_dir, True)
|
616 |
+
|
617 |
+
return subsets_config
|
618 |
+
|
619 |
+
|
620 |
+
def generate_controlnet_subsets_config_by_subdirs(
|
621 |
+
train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"
|
622 |
+
):
|
623 |
+
def generate(base_dir: Optional[str]):
|
624 |
+
if base_dir is None:
|
625 |
+
return []
|
626 |
+
|
627 |
+
base_dir: Path = Path(base_dir)
|
628 |
+
if not base_dir.is_dir():
|
629 |
+
return []
|
630 |
+
|
631 |
+
subsets_config = []
|
632 |
+
subset_config = {
|
633 |
+
"image_dir": train_data_dir,
|
634 |
+
"conditioning_data_dir": conditioning_data_dir,
|
635 |
+
"caption_extension": caption_extension,
|
636 |
+
"num_repeats": 1,
|
637 |
+
}
|
638 |
+
subsets_config.append(subset_config)
|
639 |
+
|
640 |
+
return subsets_config
|
641 |
+
|
642 |
+
subsets_config = []
|
643 |
+
subsets_config += generate(train_data_dir)
|
644 |
+
|
645 |
+
return subsets_config
|
646 |
+
|
647 |
+
|
648 |
+
def load_user_config(file: str) -> dict:
|
649 |
+
file_path: Path = Path(file)
|
650 |
+
if not file_path.is_file():
|
651 |
+
#raise ValueError(f"file not found / ファイルが見つかりません: {file}")
|
652 |
+
return toml.loads(file)
|
653 |
+
|
654 |
+
if file_path.name.lower().endswith(".json"):
|
655 |
+
try:
|
656 |
+
with open(file, "r") as f:
|
657 |
+
config = json.load(f)
|
658 |
+
except Exception:
|
659 |
+
logger.error(
|
660 |
+
f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
661 |
+
)
|
662 |
+
raise
|
663 |
+
elif file_path.name.lower().endswith(".toml"):
|
664 |
+
try:
|
665 |
+
config = toml.load(file_path)
|
666 |
+
except Exception:
|
667 |
+
logger.error(
|
668 |
+
f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
669 |
+
)
|
670 |
+
raise
|
671 |
+
else:
|
672 |
+
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file_path}")
|
673 |
+
|
674 |
+
return config
|
675 |
+
|
676 |
+
|
677 |
+
# for config test
|
678 |
+
if __name__ == "__main__":
|
679 |
+
parser = argparse.ArgumentParser()
|
680 |
+
parser.add_argument("--support_dreambooth", action="store_true")
|
681 |
+
parser.add_argument("--support_finetuning", action="store_true")
|
682 |
+
parser.add_argument("--support_controlnet", action="store_true")
|
683 |
+
parser.add_argument("--support_dropout", action="store_true")
|
684 |
+
parser.add_argument("dataset_config")
|
685 |
+
config_args, remain = parser.parse_known_args()
|
686 |
+
|
687 |
+
parser = argparse.ArgumentParser()
|
688 |
+
train_util.add_dataset_arguments(
|
689 |
+
parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout
|
690 |
+
)
|
691 |
+
train_util.add_training_arguments(parser, config_args.support_dreambooth)
|
692 |
+
argparse_namespace = parser.parse_args(remain)
|
693 |
+
train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
|
694 |
+
|
695 |
+
logger.info("[argparse_namespace]")
|
696 |
+
logger.info(f"{vars(argparse_namespace)}")
|
697 |
+
|
698 |
+
user_config = load_user_config(config_args.dataset_config)
|
699 |
+
|
700 |
+
logger.info("")
|
701 |
+
logger.info("[user_config]")
|
702 |
+
logger.info(f"{user_config}")
|
703 |
+
|
704 |
+
sanitizer = ConfigSanitizer(
|
705 |
+
config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
|
706 |
+
)
|
707 |
+
sanitized_user_config = sanitizer.sanitize_user_config(user_config)
|
708 |
+
|
709 |
+
logger.info("")
|
710 |
+
logger.info("[sanitized_user_config]")
|
711 |
+
logger.info(f"{sanitized_user_config}")
|
712 |
+
|
713 |
+
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
|
714 |
+
|
715 |
+
logger.info("")
|
716 |
+
logger.info("[blueprint]")
|
717 |
+
logger.info(f"{blueprint}")
|
library/custom_offloading_utils.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from concurrent.futures import ThreadPoolExecutor
|
2 |
+
import time
|
3 |
+
from typing import Optional
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from .device_utils import clean_memory_on_device
|
8 |
+
|
9 |
+
|
10 |
+
def synchronize_device(device: torch.device):
|
11 |
+
if device.type == "cuda":
|
12 |
+
torch.cuda.synchronize()
|
13 |
+
elif device.type == "xpu":
|
14 |
+
torch.xpu.synchronize()
|
15 |
+
elif device.type == "mps":
|
16 |
+
torch.mps.synchronize()
|
17 |
+
|
18 |
+
|
19 |
+
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
20 |
+
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
21 |
+
|
22 |
+
weight_swap_jobs = []
|
23 |
+
|
24 |
+
# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
|
25 |
+
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
26 |
+
# print(module_to_cpu.__class__, module_to_cuda.__class__)
|
27 |
+
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
28 |
+
# weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
29 |
+
|
30 |
+
modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
|
31 |
+
for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
|
32 |
+
if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
|
33 |
+
module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
|
34 |
+
if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
|
35 |
+
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
36 |
+
else:
|
37 |
+
if module_to_cuda.weight.data.device.type != device.type:
|
38 |
+
# print(
|
39 |
+
# f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
|
40 |
+
# )
|
41 |
+
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
|
42 |
+
|
43 |
+
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
44 |
+
|
45 |
+
stream = torch.cuda.Stream()
|
46 |
+
with torch.cuda.stream(stream):
|
47 |
+
# cuda to cpu
|
48 |
+
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
49 |
+
cuda_data_view.record_stream(stream)
|
50 |
+
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
51 |
+
|
52 |
+
stream.synchronize()
|
53 |
+
|
54 |
+
# cpu to cuda
|
55 |
+
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
56 |
+
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
57 |
+
module_to_cuda.weight.data = cuda_data_view
|
58 |
+
|
59 |
+
stream.synchronize()
|
60 |
+
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
61 |
+
|
62 |
+
|
63 |
+
def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
64 |
+
"""
|
65 |
+
not tested
|
66 |
+
"""
|
67 |
+
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
68 |
+
|
69 |
+
weight_swap_jobs = []
|
70 |
+
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
71 |
+
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
72 |
+
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
73 |
+
|
74 |
+
# device to cpu
|
75 |
+
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
76 |
+
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
77 |
+
|
78 |
+
synchronize_device()
|
79 |
+
|
80 |
+
# cpu to device
|
81 |
+
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
82 |
+
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
83 |
+
module_to_cuda.weight.data = cuda_data_view
|
84 |
+
|
85 |
+
synchronize_device()
|
86 |
+
|
87 |
+
|
88 |
+
def weighs_to_device(layer: nn.Module, device: torch.device):
|
89 |
+
for module in layer.modules():
|
90 |
+
if hasattr(module, "weight") and module.weight is not None:
|
91 |
+
module.weight.data = module.weight.data.to(device, non_blocking=True)
|
92 |
+
|
93 |
+
|
94 |
+
class Offloader:
|
95 |
+
"""
|
96 |
+
common offloading class
|
97 |
+
"""
|
98 |
+
|
99 |
+
def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
|
100 |
+
self.num_blocks = num_blocks
|
101 |
+
self.blocks_to_swap = blocks_to_swap
|
102 |
+
self.device = device
|
103 |
+
self.debug = debug
|
104 |
+
|
105 |
+
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
106 |
+
self.futures = {}
|
107 |
+
self.cuda_available = device.type == "cuda"
|
108 |
+
|
109 |
+
def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
|
110 |
+
if self.cuda_available:
|
111 |
+
swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
|
112 |
+
else:
|
113 |
+
swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)
|
114 |
+
|
115 |
+
def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
|
116 |
+
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
|
117 |
+
if self.debug:
|
118 |
+
start_time = time.perf_counter()
|
119 |
+
print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")
|
120 |
+
|
121 |
+
self.swap_weight_devices(block_to_cpu, block_to_cuda)
|
122 |
+
|
123 |
+
if self.debug:
|
124 |
+
print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
|
125 |
+
return bidx_to_cpu, bidx_to_cuda # , event
|
126 |
+
|
127 |
+
block_to_cpu = blocks[block_idx_to_cpu]
|
128 |
+
block_to_cuda = blocks[block_idx_to_cuda]
|
129 |
+
|
130 |
+
self.futures[block_idx_to_cuda] = self.thread_pool.submit(
|
131 |
+
move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
|
132 |
+
)
|
133 |
+
|
134 |
+
def _wait_blocks_move(self, block_idx):
|
135 |
+
if block_idx not in self.futures:
|
136 |
+
return
|
137 |
+
|
138 |
+
if self.debug:
|
139 |
+
print(f"Wait for block {block_idx}")
|
140 |
+
start_time = time.perf_counter()
|
141 |
+
|
142 |
+
future = self.futures.pop(block_idx)
|
143 |
+
_, bidx_to_cuda = future.result()
|
144 |
+
|
145 |
+
assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
|
146 |
+
|
147 |
+
if self.debug:
|
148 |
+
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
|
149 |
+
|
150 |
+
|
151 |
+
class ModelOffloader(Offloader):
|
152 |
+
"""
|
153 |
+
supports forward offloading
|
154 |
+
"""
|
155 |
+
|
156 |
+
def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
|
157 |
+
super().__init__(num_blocks, blocks_to_swap, device, debug)
|
158 |
+
|
159 |
+
# register backward hooks
|
160 |
+
self.remove_handles = []
|
161 |
+
for i, block in enumerate(blocks):
|
162 |
+
hook = self.create_backward_hook(blocks, i)
|
163 |
+
if hook is not None:
|
164 |
+
handle = block.register_full_backward_hook(hook)
|
165 |
+
self.remove_handles.append(handle)
|
166 |
+
|
167 |
+
def __del__(self):
|
168 |
+
for handle in self.remove_handles:
|
169 |
+
handle.remove()
|
170 |
+
|
171 |
+
def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
|
172 |
+
# -1 for 0-based index
|
173 |
+
num_blocks_propagated = self.num_blocks - block_index - 1
|
174 |
+
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
|
175 |
+
waiting = block_index > 0 and block_index <= self.blocks_to_swap
|
176 |
+
|
177 |
+
if not swapping and not waiting:
|
178 |
+
return None
|
179 |
+
|
180 |
+
# create hook
|
181 |
+
block_idx_to_cpu = self.num_blocks - num_blocks_propagated
|
182 |
+
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
|
183 |
+
block_idx_to_wait = block_index - 1
|
184 |
+
|
185 |
+
def backward_hook(module, grad_input, grad_output):
|
186 |
+
if self.debug:
|
187 |
+
print(f"Backward hook for block {block_index}")
|
188 |
+
|
189 |
+
if swapping:
|
190 |
+
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
|
191 |
+
if waiting:
|
192 |
+
self._wait_blocks_move(block_idx_to_wait)
|
193 |
+
return None
|
194 |
+
|
195 |
+
return backward_hook
|
196 |
+
|
197 |
+
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
|
198 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
199 |
+
return
|
200 |
+
|
201 |
+
if self.debug:
|
202 |
+
print("Prepare block devices before forward")
|
203 |
+
|
204 |
+
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
|
205 |
+
b.to(self.device)
|
206 |
+
weighs_to_device(b, self.device) # make sure weights are on device
|
207 |
+
|
208 |
+
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
|
209 |
+
b.to(self.device) # move block to device first
|
210 |
+
weighs_to_device(b, "cpu") # make sure weights are on cpu
|
211 |
+
|
212 |
+
synchronize_device(self.device)
|
213 |
+
clean_memory_on_device(self.device)
|
214 |
+
|
215 |
+
def wait_for_block(self, block_idx: int):
|
216 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
217 |
+
return
|
218 |
+
self._wait_blocks_move(block_idx)
|
219 |
+
|
220 |
+
def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
|
221 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
222 |
+
return
|
223 |
+
if block_idx >= self.blocks_to_swap:
|
224 |
+
return
|
225 |
+
block_idx_to_cpu = block_idx
|
226 |
+
block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
|
227 |
+
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
|
library/custom_train_functions.py
ADDED
@@ -0,0 +1,556 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import argparse
|
3 |
+
import random
|
4 |
+
import re
|
5 |
+
from typing import List, Optional, Union
|
6 |
+
from .utils import setup_logging
|
7 |
+
|
8 |
+
setup_logging()
|
9 |
+
import logging
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
15 |
+
if hasattr(noise_scheduler, "all_snr"):
|
16 |
+
return
|
17 |
+
|
18 |
+
alphas_cumprod = noise_scheduler.alphas_cumprod
|
19 |
+
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
|
20 |
+
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
|
21 |
+
alpha = sqrt_alphas_cumprod
|
22 |
+
sigma = sqrt_one_minus_alphas_cumprod
|
23 |
+
all_snr = (alpha / sigma) ** 2
|
24 |
+
|
25 |
+
noise_scheduler.all_snr = all_snr.to(device)
|
26 |
+
|
27 |
+
|
28 |
+
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
29 |
+
# fix beta: zero terminal SNR
|
30 |
+
logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
|
31 |
+
|
32 |
+
def enforce_zero_terminal_snr(betas):
|
33 |
+
# Convert betas to alphas_bar_sqrt
|
34 |
+
alphas = 1 - betas
|
35 |
+
alphas_bar = alphas.cumprod(0)
|
36 |
+
alphas_bar_sqrt = alphas_bar.sqrt()
|
37 |
+
|
38 |
+
# Store old values.
|
39 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
40 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
41 |
+
# Shift so last timestep is zero.
|
42 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
43 |
+
# Scale so first timestep is back to old value.
|
44 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
45 |
+
|
46 |
+
# Convert alphas_bar_sqrt to betas
|
47 |
+
alphas_bar = alphas_bar_sqrt**2
|
48 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1]
|
49 |
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
50 |
+
betas = 1 - alphas
|
51 |
+
return betas
|
52 |
+
|
53 |
+
betas = noise_scheduler.betas
|
54 |
+
betas = enforce_zero_terminal_snr(betas)
|
55 |
+
alphas = 1.0 - betas
|
56 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
57 |
+
|
58 |
+
# logger.info(f"original: {noise_scheduler.betas}")
|
59 |
+
# logger.info(f"fixed: {betas}")
|
60 |
+
|
61 |
+
noise_scheduler.betas = betas
|
62 |
+
noise_scheduler.alphas = alphas
|
63 |
+
noise_scheduler.alphas_cumprod = alphas_cumprod
|
64 |
+
|
65 |
+
|
66 |
+
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
|
67 |
+
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
68 |
+
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
|
69 |
+
if v_prediction:
|
70 |
+
snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
|
71 |
+
else:
|
72 |
+
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
|
73 |
+
loss = loss * snr_weight
|
74 |
+
return loss
|
75 |
+
|
76 |
+
|
77 |
+
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
|
78 |
+
scale = get_snr_scale(timesteps, noise_scheduler)
|
79 |
+
loss = loss * scale
|
80 |
+
return loss
|
81 |
+
|
82 |
+
|
83 |
+
def get_snr_scale(timesteps, noise_scheduler):
|
84 |
+
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
85 |
+
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
86 |
+
scale = snr_t / (snr_t + 1)
|
87 |
+
# # show debug info
|
88 |
+
# logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
|
89 |
+
return scale
|
90 |
+
|
91 |
+
|
92 |
+
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
|
93 |
+
scale = get_snr_scale(timesteps, noise_scheduler)
|
94 |
+
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
|
95 |
+
loss = loss + loss / scale * v_pred_like_loss
|
96 |
+
return loss
|
97 |
+
|
98 |
+
|
99 |
+
def apply_debiased_estimation(loss, timesteps, noise_scheduler):
|
100 |
+
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
101 |
+
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
102 |
+
weight = 1 / torch.sqrt(snr_t)
|
103 |
+
loss = weight * loss
|
104 |
+
return loss
|
105 |
+
|
106 |
+
|
107 |
+
# TODO train_utilと分散しているのでどちらかに寄せる
|
108 |
+
|
109 |
+
|
110 |
+
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
|
111 |
+
parser.add_argument(
|
112 |
+
"--min_snr_gamma",
|
113 |
+
type=float,
|
114 |
+
default=None,
|
115 |
+
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
|
116 |
+
)
|
117 |
+
parser.add_argument(
|
118 |
+
"--scale_v_pred_loss_like_noise_pred",
|
119 |
+
action="store_true",
|
120 |
+
help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
|
121 |
+
)
|
122 |
+
parser.add_argument(
|
123 |
+
"--v_pred_like_loss",
|
124 |
+
type=float,
|
125 |
+
default=None,
|
126 |
+
help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
|
127 |
+
)
|
128 |
+
parser.add_argument(
|
129 |
+
"--debiased_estimation_loss",
|
130 |
+
action="store_true",
|
131 |
+
help="debiased estimation loss / debiased estimation loss",
|
132 |
+
)
|
133 |
+
if support_weighted_captions:
|
134 |
+
parser.add_argument(
|
135 |
+
"--weighted_captions",
|
136 |
+
action="store_true",
|
137 |
+
default=False,
|
138 |
+
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
|
139 |
+
)
|
140 |
+
|
141 |
+
|
142 |
+
re_attention = re.compile(
|
143 |
+
r"""
|
144 |
+
\\\(|
|
145 |
+
\\\)|
|
146 |
+
\\\[|
|
147 |
+
\\]|
|
148 |
+
\\\\|
|
149 |
+
\\|
|
150 |
+
\(|
|
151 |
+
\[|
|
152 |
+
:([+-]?[.\d]+)\)|
|
153 |
+
\)|
|
154 |
+
]|
|
155 |
+
[^\\()\[\]:]+|
|
156 |
+
:
|
157 |
+
""",
|
158 |
+
re.X,
|
159 |
+
)
|
160 |
+
|
161 |
+
|
162 |
+
def parse_prompt_attention(text):
|
163 |
+
"""
|
164 |
+
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
165 |
+
Accepted tokens are:
|
166 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
167 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
168 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
169 |
+
\( - literal character '('
|
170 |
+
\[ - literal character '['
|
171 |
+
\) - literal character ')'
|
172 |
+
\] - literal character ']'
|
173 |
+
\\ - literal character '\'
|
174 |
+
anything else - just text
|
175 |
+
>>> parse_prompt_attention('normal text')
|
176 |
+
[['normal text', 1.0]]
|
177 |
+
>>> parse_prompt_attention('an (important) word')
|
178 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
179 |
+
>>> parse_prompt_attention('(unbalanced')
|
180 |
+
[['unbalanced', 1.1]]
|
181 |
+
>>> parse_prompt_attention('\(literal\]')
|
182 |
+
[['(literal]', 1.0]]
|
183 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
184 |
+
[['unnecessaryparens', 1.1]]
|
185 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
186 |
+
[['a ', 1.0],
|
187 |
+
['house', 1.5730000000000004],
|
188 |
+
[' ', 1.1],
|
189 |
+
['on', 1.0],
|
190 |
+
[' a ', 1.1],
|
191 |
+
['hill', 0.55],
|
192 |
+
[', sun, ', 1.1],
|
193 |
+
['sky', 1.4641000000000006],
|
194 |
+
['.', 1.1]]
|
195 |
+
"""
|
196 |
+
|
197 |
+
res = []
|
198 |
+
round_brackets = []
|
199 |
+
square_brackets = []
|
200 |
+
|
201 |
+
round_bracket_multiplier = 1.1
|
202 |
+
square_bracket_multiplier = 1 / 1.1
|
203 |
+
|
204 |
+
def multiply_range(start_position, multiplier):
|
205 |
+
for p in range(start_position, len(res)):
|
206 |
+
res[p][1] *= multiplier
|
207 |
+
|
208 |
+
for m in re_attention.finditer(text):
|
209 |
+
text = m.group(0)
|
210 |
+
weight = m.group(1)
|
211 |
+
|
212 |
+
if text.startswith("\\"):
|
213 |
+
res.append([text[1:], 1.0])
|
214 |
+
elif text == "(":
|
215 |
+
round_brackets.append(len(res))
|
216 |
+
elif text == "[":
|
217 |
+
square_brackets.append(len(res))
|
218 |
+
elif weight is not None and len(round_brackets) > 0:
|
219 |
+
multiply_range(round_brackets.pop(), float(weight))
|
220 |
+
elif text == ")" and len(round_brackets) > 0:
|
221 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
222 |
+
elif text == "]" and len(square_brackets) > 0:
|
223 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
224 |
+
else:
|
225 |
+
res.append([text, 1.0])
|
226 |
+
|
227 |
+
for pos in round_brackets:
|
228 |
+
multiply_range(pos, round_bracket_multiplier)
|
229 |
+
|
230 |
+
for pos in square_brackets:
|
231 |
+
multiply_range(pos, square_bracket_multiplier)
|
232 |
+
|
233 |
+
if len(res) == 0:
|
234 |
+
res = [["", 1.0]]
|
235 |
+
|
236 |
+
# merge runs of identical weights
|
237 |
+
i = 0
|
238 |
+
while i + 1 < len(res):
|
239 |
+
if res[i][1] == res[i + 1][1]:
|
240 |
+
res[i][0] += res[i + 1][0]
|
241 |
+
res.pop(i + 1)
|
242 |
+
else:
|
243 |
+
i += 1
|
244 |
+
|
245 |
+
return res
|
246 |
+
|
247 |
+
|
248 |
+
def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
|
249 |
+
r"""
|
250 |
+
Tokenize a list of prompts and return its tokens with weights of each token.
|
251 |
+
|
252 |
+
No padding, starting or ending token is included.
|
253 |
+
"""
|
254 |
+
tokens = []
|
255 |
+
weights = []
|
256 |
+
truncated = False
|
257 |
+
for text in prompt:
|
258 |
+
texts_and_weights = parse_prompt_attention(text)
|
259 |
+
text_token = []
|
260 |
+
text_weight = []
|
261 |
+
for word, weight in texts_and_weights:
|
262 |
+
# tokenize and discard the starting and the ending token
|
263 |
+
token = tokenizer(word).input_ids[1:-1]
|
264 |
+
text_token += token
|
265 |
+
# copy the weight by length of token
|
266 |
+
text_weight += [weight] * len(token)
|
267 |
+
# stop if the text is too long (longer than truncation limit)
|
268 |
+
if len(text_token) > max_length:
|
269 |
+
truncated = True
|
270 |
+
break
|
271 |
+
# truncate
|
272 |
+
if len(text_token) > max_length:
|
273 |
+
truncated = True
|
274 |
+
text_token = text_token[:max_length]
|
275 |
+
text_weight = text_weight[:max_length]
|
276 |
+
tokens.append(text_token)
|
277 |
+
weights.append(text_weight)
|
278 |
+
if truncated:
|
279 |
+
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
280 |
+
return tokens, weights
|
281 |
+
|
282 |
+
|
283 |
+
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
|
284 |
+
r"""
|
285 |
+
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
286 |
+
"""
|
287 |
+
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
288 |
+
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
289 |
+
for i in range(len(tokens)):
|
290 |
+
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
|
291 |
+
if no_boseos_middle:
|
292 |
+
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
293 |
+
else:
|
294 |
+
w = []
|
295 |
+
if len(weights[i]) == 0:
|
296 |
+
w = [1.0] * weights_length
|
297 |
+
else:
|
298 |
+
for j in range(max_embeddings_multiples):
|
299 |
+
w.append(1.0) # weight for starting token in this chunk
|
300 |
+
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
|
301 |
+
w.append(1.0) # weight for ending token in this chunk
|
302 |
+
w += [1.0] * (weights_length - len(w))
|
303 |
+
weights[i] = w[:]
|
304 |
+
|
305 |
+
return tokens, weights
|
306 |
+
|
307 |
+
|
308 |
+
def get_unweighted_text_embeddings(
|
309 |
+
tokenizer,
|
310 |
+
text_encoder,
|
311 |
+
text_input: torch.Tensor,
|
312 |
+
chunk_length: int,
|
313 |
+
clip_skip: int,
|
314 |
+
eos: int,
|
315 |
+
pad: int,
|
316 |
+
no_boseos_middle: Optional[bool] = True,
|
317 |
+
):
|
318 |
+
"""
|
319 |
+
When the length of tokens is a multiple of the capacity of the text encoder,
|
320 |
+
it should be split into chunks and sent to the text encoder individually.
|
321 |
+
"""
|
322 |
+
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
|
323 |
+
if max_embeddings_multiples > 1:
|
324 |
+
text_embeddings = []
|
325 |
+
for i in range(max_embeddings_multiples):
|
326 |
+
# extract the i-th chunk
|
327 |
+
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
328 |
+
|
329 |
+
# cover the head and the tail by the starting and the ending tokens
|
330 |
+
text_input_chunk[:, 0] = text_input[0, 0]
|
331 |
+
if pad == eos: # v1
|
332 |
+
text_input_chunk[:, -1] = text_input[0, -1]
|
333 |
+
else: # v2
|
334 |
+
for j in range(len(text_input_chunk)):
|
335 |
+
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
|
336 |
+
text_input_chunk[j, -1] = eos
|
337 |
+
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
|
338 |
+
text_input_chunk[j, 1] = eos
|
339 |
+
|
340 |
+
if clip_skip is None or clip_skip == 1:
|
341 |
+
text_embedding = text_encoder(text_input_chunk)[0]
|
342 |
+
else:
|
343 |
+
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
|
344 |
+
text_embedding = enc_out["hidden_states"][-clip_skip]
|
345 |
+
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
|
346 |
+
|
347 |
+
if no_boseos_middle:
|
348 |
+
if i == 0:
|
349 |
+
# discard the ending token
|
350 |
+
text_embedding = text_embedding[:, :-1]
|
351 |
+
elif i == max_embeddings_multiples - 1:
|
352 |
+
# discard the starting token
|
353 |
+
text_embedding = text_embedding[:, 1:]
|
354 |
+
else:
|
355 |
+
# discard both starting and ending tokens
|
356 |
+
text_embedding = text_embedding[:, 1:-1]
|
357 |
+
|
358 |
+
text_embeddings.append(text_embedding)
|
359 |
+
text_embeddings = torch.concat(text_embeddings, axis=1)
|
360 |
+
else:
|
361 |
+
if clip_skip is None or clip_skip == 1:
|
362 |
+
text_embeddings = text_encoder(text_input)[0]
|
363 |
+
else:
|
364 |
+
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
365 |
+
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
366 |
+
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
|
367 |
+
return text_embeddings
|
368 |
+
|
369 |
+
|
370 |
+
def get_weighted_text_embeddings(
|
371 |
+
tokenizer,
|
372 |
+
text_encoder,
|
373 |
+
prompt: Union[str, List[str]],
|
374 |
+
device,
|
375 |
+
max_embeddings_multiples: Optional[int] = 3,
|
376 |
+
no_boseos_middle: Optional[bool] = False,
|
377 |
+
clip_skip=None,
|
378 |
+
):
|
379 |
+
r"""
|
380 |
+
Prompts can be assigned with local weights using brackets. For example,
|
381 |
+
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
|
382 |
+
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
|
383 |
+
|
384 |
+
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
|
385 |
+
|
386 |
+
Args:
|
387 |
+
prompt (`str` or `List[str]`):
|
388 |
+
The prompt or prompts to guide the image generation.
|
389 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
390 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
391 |
+
no_boseos_middle (`bool`, *optional*, defaults to `False`):
|
392 |
+
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
|
393 |
+
ending token in each of the chunk in the middle.
|
394 |
+
skip_parsing (`bool`, *optional*, defaults to `False`):
|
395 |
+
Skip the parsing of brackets.
|
396 |
+
skip_weighting (`bool`, *optional*, defaults to `False`):
|
397 |
+
Skip the weighting. When the parsing is skipped, it is forced True.
|
398 |
+
"""
|
399 |
+
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
400 |
+
if isinstance(prompt, str):
|
401 |
+
prompt = [prompt]
|
402 |
+
|
403 |
+
prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
|
404 |
+
|
405 |
+
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
406 |
+
max_length = max([len(token) for token in prompt_tokens])
|
407 |
+
|
408 |
+
max_embeddings_multiples = min(
|
409 |
+
max_embeddings_multiples,
|
410 |
+
(max_length - 1) // (tokenizer.model_max_length - 2) + 1,
|
411 |
+
)
|
412 |
+
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
413 |
+
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
414 |
+
|
415 |
+
# pad the length of tokens and weights
|
416 |
+
bos = tokenizer.bos_token_id
|
417 |
+
eos = tokenizer.eos_token_id
|
418 |
+
pad = tokenizer.pad_token_id
|
419 |
+
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
420 |
+
prompt_tokens,
|
421 |
+
prompt_weights,
|
422 |
+
max_length,
|
423 |
+
bos,
|
424 |
+
eos,
|
425 |
+
no_boseos_middle=no_boseos_middle,
|
426 |
+
chunk_length=tokenizer.model_max_length,
|
427 |
+
)
|
428 |
+
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
|
429 |
+
|
430 |
+
# get the embeddings
|
431 |
+
text_embeddings = get_unweighted_text_embeddings(
|
432 |
+
tokenizer,
|
433 |
+
text_encoder,
|
434 |
+
prompt_tokens,
|
435 |
+
tokenizer.model_max_length,
|
436 |
+
clip_skip,
|
437 |
+
eos,
|
438 |
+
pad,
|
439 |
+
no_boseos_middle=no_boseos_middle,
|
440 |
+
)
|
441 |
+
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
|
442 |
+
|
443 |
+
# assign weights to the prompts and normalize in the sense of mean
|
444 |
+
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
445 |
+
text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
|
446 |
+
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
447 |
+
text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
448 |
+
|
449 |
+
return text_embeddings
|
450 |
+
|
451 |
+
|
452 |
+
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
|
453 |
+
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
|
454 |
+
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
|
455 |
+
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
|
456 |
+
for i in range(iterations):
|
457 |
+
r = random.random() * 2 + 2 # Rather than always going 2x,
|
458 |
+
wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
|
459 |
+
noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
|
460 |
+
if wn == 1 or hn == 1:
|
461 |
+
break # Lowest resolution is 1x1
|
462 |
+
return noise / noise.std() # Scaled back to roughly unit variance
|
463 |
+
|
464 |
+
|
465 |
+
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
466 |
+
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
467 |
+
if noise_offset is None:
|
468 |
+
return noise
|
469 |
+
if adaptive_noise_scale is not None:
|
470 |
+
# latent shape: (batch_size, channels, height, width)
|
471 |
+
# abs mean value for each channel
|
472 |
+
latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
|
473 |
+
|
474 |
+
# multiply adaptive noise scale to the mean value and add it to the noise offset
|
475 |
+
noise_offset = noise_offset + adaptive_noise_scale * latent_mean
|
476 |
+
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
|
477 |
+
|
478 |
+
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
479 |
+
return noise
|
480 |
+
|
481 |
+
|
482 |
+
def apply_masked_loss(loss, batch):
|
483 |
+
if "conditioning_images" in batch:
|
484 |
+
# conditioning image is -1 to 1. we need to convert it to 0 to 1
|
485 |
+
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
|
486 |
+
mask_image = mask_image / 2 + 0.5
|
487 |
+
# print(f"conditioning_image: {mask_image.shape}")
|
488 |
+
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
|
489 |
+
# alpha mask is 0 to 1
|
490 |
+
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
|
491 |
+
# print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
|
492 |
+
else:
|
493 |
+
return loss
|
494 |
+
|
495 |
+
# resize to the same size as the loss
|
496 |
+
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
|
497 |
+
loss = loss * mask_image
|
498 |
+
return loss
|
499 |
+
|
500 |
+
|
501 |
+
"""
|
502 |
+
##########################################
|
503 |
+
# Perlin Noise
|
504 |
+
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
505 |
+
delta = (res[0] / shape[0], res[1] / shape[1])
|
506 |
+
d = (shape[0] // res[0], shape[1] // res[1])
|
507 |
+
|
508 |
+
grid = (
|
509 |
+
torch.stack(
|
510 |
+
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
|
511 |
+
dim=-1,
|
512 |
+
)
|
513 |
+
% 1
|
514 |
+
)
|
515 |
+
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
|
516 |
+
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
|
517 |
+
|
518 |
+
tile_grads = (
|
519 |
+
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
520 |
+
.repeat_interleave(d[0], 0)
|
521 |
+
.repeat_interleave(d[1], 1)
|
522 |
+
)
|
523 |
+
dot = lambda grad, shift: (
|
524 |
+
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
|
525 |
+
* grad[: shape[0], : shape[1]]
|
526 |
+
).sum(dim=-1)
|
527 |
+
|
528 |
+
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
529 |
+
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
530 |
+
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
|
531 |
+
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
|
532 |
+
t = fade(grid[: shape[0], : shape[1]])
|
533 |
+
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
534 |
+
|
535 |
+
|
536 |
+
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
|
537 |
+
noise = torch.zeros(shape, device=device)
|
538 |
+
frequency = 1
|
539 |
+
amplitude = 1
|
540 |
+
for _ in range(octaves):
|
541 |
+
noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
|
542 |
+
frequency *= 2
|
543 |
+
amplitude *= persistence
|
544 |
+
return noise
|
545 |
+
|
546 |
+
|
547 |
+
def perlin_noise(noise, device, octaves):
|
548 |
+
_, c, w, h = noise.shape
|
549 |
+
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
|
550 |
+
noise_perlin = []
|
551 |
+
for _ in range(c):
|
552 |
+
noise_perlin.append(perlin())
|
553 |
+
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
|
554 |
+
noise += noise_perlin # broadcast for each batch
|
555 |
+
return noise / noise.std() # Scaled back to roughly unit variance
|
556 |
+
"""
|
library/deepspeed_utils.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import torch
|
4 |
+
from accelerate import DeepSpeedPlugin, Accelerator
|
5 |
+
|
6 |
+
from .utils import setup_logging
|
7 |
+
|
8 |
+
setup_logging()
|
9 |
+
import logging
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
def add_deepspeed_arguments(parser: argparse.ArgumentParser):
|
15 |
+
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
|
16 |
+
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
|
17 |
+
parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
|
18 |
+
parser.add_argument(
|
19 |
+
"--offload_optimizer_device",
|
20 |
+
type=str,
|
21 |
+
default=None,
|
22 |
+
choices=[None, "cpu", "nvme"],
|
23 |
+
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
|
24 |
+
)
|
25 |
+
parser.add_argument(
|
26 |
+
"--offload_optimizer_nvme_path",
|
27 |
+
type=str,
|
28 |
+
default=None,
|
29 |
+
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
|
30 |
+
)
|
31 |
+
parser.add_argument(
|
32 |
+
"--offload_param_device",
|
33 |
+
type=str,
|
34 |
+
default=None,
|
35 |
+
choices=[None, "cpu", "nvme"],
|
36 |
+
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
|
37 |
+
)
|
38 |
+
parser.add_argument(
|
39 |
+
"--offload_param_nvme_path",
|
40 |
+
type=str,
|
41 |
+
default=None,
|
42 |
+
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
|
43 |
+
)
|
44 |
+
parser.add_argument(
|
45 |
+
"--zero3_init_flag",
|
46 |
+
action="store_true",
|
47 |
+
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
|
48 |
+
"Only applicable with ZeRO Stage-3.",
|
49 |
+
)
|
50 |
+
parser.add_argument(
|
51 |
+
"--zero3_save_16bit_model",
|
52 |
+
action="store_true",
|
53 |
+
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
|
54 |
+
)
|
55 |
+
parser.add_argument(
|
56 |
+
"--fp16_master_weights_and_gradients",
|
57 |
+
action="store_true",
|
58 |
+
help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
def prepare_deepspeed_args(args: argparse.Namespace):
|
63 |
+
if not args.deepspeed:
|
64 |
+
return
|
65 |
+
|
66 |
+
# To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
|
67 |
+
args.max_data_loader_n_workers = 1
|
68 |
+
|
69 |
+
|
70 |
+
def prepare_deepspeed_plugin(args: argparse.Namespace):
|
71 |
+
if not args.deepspeed:
|
72 |
+
return None
|
73 |
+
|
74 |
+
try:
|
75 |
+
import deepspeed
|
76 |
+
except ImportError as e:
|
77 |
+
logger.error(
|
78 |
+
"deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
|
79 |
+
)
|
80 |
+
exit(1)
|
81 |
+
|
82 |
+
deepspeed_plugin = DeepSpeedPlugin(
|
83 |
+
zero_stage=args.zero_stage,
|
84 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
85 |
+
gradient_clipping=args.max_grad_norm,
|
86 |
+
offload_optimizer_device=args.offload_optimizer_device,
|
87 |
+
offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
|
88 |
+
offload_param_device=args.offload_param_device,
|
89 |
+
offload_param_nvme_path=args.offload_param_nvme_path,
|
90 |
+
zero3_init_flag=args.zero3_init_flag,
|
91 |
+
zero3_save_16bit_model=args.zero3_save_16bit_model,
|
92 |
+
)
|
93 |
+
deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
|
94 |
+
deepspeed_plugin.deepspeed_config["train_batch_size"] = 1#(
|
95 |
+
# args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
|
96 |
+
#)
|
97 |
+
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
|
98 |
+
if args.mixed_precision.lower() == "fp16":
|
99 |
+
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
|
100 |
+
if args.full_fp16 or args.fp16_master_weights_and_gradients:
|
101 |
+
if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
|
102 |
+
deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
|
103 |
+
logger.info("[DeepSpeed] full fp16 enable.")
|
104 |
+
else:
|
105 |
+
logger.info(
|
106 |
+
"[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
|
107 |
+
)
|
108 |
+
|
109 |
+
if args.offload_optimizer_device is not None:
|
110 |
+
logger.info("[DeepSpeed] start to manually build cpu_adam.")
|
111 |
+
deepspeed.ops.op_builder.CPUAdamBuilder().load()
|
112 |
+
logger.info("[DeepSpeed] building cpu_adam done.")
|
113 |
+
|
114 |
+
return deepspeed_plugin
|
115 |
+
|
116 |
+
|
117 |
+
# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
|
118 |
+
def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
119 |
+
# remove None from models
|
120 |
+
models = {k: v for k, v in models.items() if v is not None}
|
121 |
+
|
122 |
+
class DeepSpeedWrapper(torch.nn.Module):
|
123 |
+
def __init__(self, **kw_models) -> None:
|
124 |
+
super().__init__()
|
125 |
+
self.models = torch.nn.ModuleDict()
|
126 |
+
|
127 |
+
for key, model in kw_models.items():
|
128 |
+
if isinstance(model, list):
|
129 |
+
model = torch.nn.ModuleList(model)
|
130 |
+
assert isinstance(
|
131 |
+
model, torch.nn.Module
|
132 |
+
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
|
133 |
+
self.models.update(torch.nn.ModuleDict({key: model}))
|
134 |
+
|
135 |
+
def get_models(self):
|
136 |
+
return self.models
|
137 |
+
|
138 |
+
ds_model = DeepSpeedWrapper(**models)
|
139 |
+
return ds_model
|
library/device_utils.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import gc
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
try:
|
7 |
+
HAS_CUDA = torch.cuda.is_available()
|
8 |
+
except Exception:
|
9 |
+
HAS_CUDA = False
|
10 |
+
|
11 |
+
try:
|
12 |
+
HAS_MPS = torch.backends.mps.is_available()
|
13 |
+
except Exception:
|
14 |
+
HAS_MPS = False
|
15 |
+
|
16 |
+
try:
|
17 |
+
import intel_extension_for_pytorch as ipex # noqa
|
18 |
+
|
19 |
+
HAS_XPU = torch.xpu.is_available()
|
20 |
+
except Exception:
|
21 |
+
HAS_XPU = False
|
22 |
+
|
23 |
+
|
24 |
+
def clean_memory():
|
25 |
+
gc.collect()
|
26 |
+
if HAS_CUDA:
|
27 |
+
torch.cuda.empty_cache()
|
28 |
+
if HAS_XPU:
|
29 |
+
torch.xpu.empty_cache()
|
30 |
+
if HAS_MPS:
|
31 |
+
torch.mps.empty_cache()
|
32 |
+
|
33 |
+
|
34 |
+
def clean_memory_on_device(device: torch.device):
|
35 |
+
r"""
|
36 |
+
Clean memory on the specified device, will be called from training scripts.
|
37 |
+
"""
|
38 |
+
gc.collect()
|
39 |
+
|
40 |
+
# device may "cuda" or "cuda:0", so we need to check the type of device
|
41 |
+
if device.type == "cuda":
|
42 |
+
torch.cuda.empty_cache()
|
43 |
+
if device.type == "xpu":
|
44 |
+
torch.xpu.empty_cache()
|
45 |
+
if device.type == "mps":
|
46 |
+
torch.mps.empty_cache()
|
47 |
+
|
48 |
+
|
49 |
+
@functools.lru_cache(maxsize=None)
|
50 |
+
def get_preferred_device() -> torch.device:
|
51 |
+
r"""
|
52 |
+
Do not call this function from training scripts. Use accelerator.device instead.
|
53 |
+
"""
|
54 |
+
if HAS_CUDA:
|
55 |
+
device = torch.device("cuda")
|
56 |
+
elif HAS_XPU:
|
57 |
+
device = torch.device("xpu")
|
58 |
+
elif HAS_MPS:
|
59 |
+
device = torch.device("mps")
|
60 |
+
else:
|
61 |
+
device = torch.device("cpu")
|
62 |
+
print(f"get_preferred_device() -> {device}")
|
63 |
+
return device
|
64 |
+
|
65 |
+
|
66 |
+
def init_ipex():
|
67 |
+
"""
|
68 |
+
Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
|
69 |
+
|
70 |
+
This function should run right after importing torch and before doing anything else.
|
71 |
+
|
72 |
+
If IPEX is not available, this function does nothing.
|
73 |
+
"""
|
74 |
+
try:
|
75 |
+
if HAS_XPU:
|
76 |
+
from .ipex import ipex_init
|
77 |
+
|
78 |
+
is_initialized, error_message = ipex_init()
|
79 |
+
if not is_initialized:
|
80 |
+
print("failed to initialize ipex:", error_message)
|
81 |
+
else:
|
82 |
+
return
|
83 |
+
except Exception as e:
|
84 |
+
print("failed to initialize ipex:", e)
|
library/flux_models.py
ADDED
@@ -0,0 +1,1060 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# copy from FLUX repo: https://github.com/black-forest-labs/flux
|
2 |
+
# license: Apache-2.0 License
|
3 |
+
|
4 |
+
from dataclasses import dataclass
|
5 |
+
import math
|
6 |
+
from typing import Dict, List, Optional, Union
|
7 |
+
|
8 |
+
from .device_utils import init_ipex
|
9 |
+
from .custom_offloading_utils import ModelOffloader
|
10 |
+
init_ipex()
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from einops import rearrange
|
14 |
+
from torch import Tensor, nn
|
15 |
+
from torch.utils.checkpoint import checkpoint
|
16 |
+
|
17 |
+
# USE_REENTRANT = True
|
18 |
+
|
19 |
+
|
20 |
+
@dataclass
|
21 |
+
class FluxParams:
|
22 |
+
in_channels: int
|
23 |
+
vec_in_dim: int
|
24 |
+
context_in_dim: int
|
25 |
+
hidden_size: int
|
26 |
+
mlp_ratio: float
|
27 |
+
num_heads: int
|
28 |
+
depth: int
|
29 |
+
depth_single_blocks: int
|
30 |
+
axes_dim: list[int]
|
31 |
+
theta: int
|
32 |
+
qkv_bias: bool
|
33 |
+
guidance_embed: bool
|
34 |
+
|
35 |
+
|
36 |
+
# region autoencoder
|
37 |
+
|
38 |
+
|
39 |
+
@dataclass
|
40 |
+
class AutoEncoderParams:
|
41 |
+
resolution: int
|
42 |
+
in_channels: int
|
43 |
+
ch: int
|
44 |
+
out_ch: int
|
45 |
+
ch_mult: list[int]
|
46 |
+
num_res_blocks: int
|
47 |
+
z_channels: int
|
48 |
+
scale_factor: float
|
49 |
+
shift_factor: float
|
50 |
+
|
51 |
+
|
52 |
+
def swish(x: Tensor) -> Tensor:
|
53 |
+
return x * torch.sigmoid(x)
|
54 |
+
|
55 |
+
|
56 |
+
class AttnBlock(nn.Module):
|
57 |
+
def __init__(self, in_channels: int):
|
58 |
+
super().__init__()
|
59 |
+
self.in_channels = in_channels
|
60 |
+
|
61 |
+
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
62 |
+
|
63 |
+
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
64 |
+
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
65 |
+
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
66 |
+
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
67 |
+
|
68 |
+
def attention(self, h_: Tensor) -> Tensor:
|
69 |
+
h_ = self.norm(h_)
|
70 |
+
q = self.q(h_)
|
71 |
+
k = self.k(h_)
|
72 |
+
v = self.v(h_)
|
73 |
+
|
74 |
+
b, c, h, w = q.shape
|
75 |
+
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
76 |
+
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
77 |
+
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
78 |
+
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
79 |
+
|
80 |
+
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
81 |
+
|
82 |
+
def forward(self, x: Tensor) -> Tensor:
|
83 |
+
return x + self.proj_out(self.attention(x))
|
84 |
+
|
85 |
+
|
86 |
+
class ResnetBlock(nn.Module):
|
87 |
+
def __init__(self, in_channels: int, out_channels: int):
|
88 |
+
super().__init__()
|
89 |
+
self.in_channels = in_channels
|
90 |
+
out_channels = in_channels if out_channels is None else out_channels
|
91 |
+
self.out_channels = out_channels
|
92 |
+
|
93 |
+
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
94 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
95 |
+
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
96 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
97 |
+
if self.in_channels != self.out_channels:
|
98 |
+
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
99 |
+
|
100 |
+
def forward(self, x):
|
101 |
+
h = x
|
102 |
+
h = self.norm1(h)
|
103 |
+
h = swish(h)
|
104 |
+
h = self.conv1(h)
|
105 |
+
|
106 |
+
h = self.norm2(h)
|
107 |
+
h = swish(h)
|
108 |
+
h = self.conv2(h)
|
109 |
+
|
110 |
+
if self.in_channels != self.out_channels:
|
111 |
+
x = self.nin_shortcut(x)
|
112 |
+
|
113 |
+
return x + h
|
114 |
+
|
115 |
+
|
116 |
+
class Downsample(nn.Module):
|
117 |
+
def __init__(self, in_channels: int):
|
118 |
+
super().__init__()
|
119 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
120 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
121 |
+
|
122 |
+
def forward(self, x: Tensor):
|
123 |
+
pad = (0, 1, 0, 1)
|
124 |
+
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
125 |
+
x = self.conv(x)
|
126 |
+
return x
|
127 |
+
|
128 |
+
|
129 |
+
class Upsample(nn.Module):
|
130 |
+
def __init__(self, in_channels: int):
|
131 |
+
super().__init__()
|
132 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
133 |
+
|
134 |
+
def forward(self, x: Tensor):
|
135 |
+
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
136 |
+
x = self.conv(x)
|
137 |
+
return x
|
138 |
+
|
139 |
+
|
140 |
+
class Encoder(nn.Module):
|
141 |
+
def __init__(
|
142 |
+
self,
|
143 |
+
resolution: int,
|
144 |
+
in_channels: int,
|
145 |
+
ch: int,
|
146 |
+
ch_mult: list[int],
|
147 |
+
num_res_blocks: int,
|
148 |
+
z_channels: int,
|
149 |
+
):
|
150 |
+
super().__init__()
|
151 |
+
self.ch = ch
|
152 |
+
self.num_resolutions = len(ch_mult)
|
153 |
+
self.num_res_blocks = num_res_blocks
|
154 |
+
self.resolution = resolution
|
155 |
+
self.in_channels = in_channels
|
156 |
+
# downsampling
|
157 |
+
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
158 |
+
|
159 |
+
curr_res = resolution
|
160 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
161 |
+
self.in_ch_mult = in_ch_mult
|
162 |
+
self.down = nn.ModuleList()
|
163 |
+
block_in = self.ch
|
164 |
+
for i_level in range(self.num_resolutions):
|
165 |
+
block = nn.ModuleList()
|
166 |
+
attn = nn.ModuleList()
|
167 |
+
block_in = ch * in_ch_mult[i_level]
|
168 |
+
block_out = ch * ch_mult[i_level]
|
169 |
+
for _ in range(self.num_res_blocks):
|
170 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
171 |
+
block_in = block_out
|
172 |
+
down = nn.Module()
|
173 |
+
down.block = block
|
174 |
+
down.attn = attn
|
175 |
+
if i_level != self.num_resolutions - 1:
|
176 |
+
down.downsample = Downsample(block_in)
|
177 |
+
curr_res = curr_res // 2
|
178 |
+
self.down.append(down)
|
179 |
+
|
180 |
+
# middle
|
181 |
+
self.mid = nn.Module()
|
182 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
183 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
184 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
185 |
+
|
186 |
+
# end
|
187 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
188 |
+
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
189 |
+
|
190 |
+
def forward(self, x: Tensor) -> Tensor:
|
191 |
+
# downsampling
|
192 |
+
hs = [self.conv_in(x)]
|
193 |
+
for i_level in range(self.num_resolutions):
|
194 |
+
for i_block in range(self.num_res_blocks):
|
195 |
+
h = self.down[i_level].block[i_block](hs[-1])
|
196 |
+
if len(self.down[i_level].attn) > 0:
|
197 |
+
h = self.down[i_level].attn[i_block](h)
|
198 |
+
hs.append(h)
|
199 |
+
if i_level != self.num_resolutions - 1:
|
200 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
201 |
+
|
202 |
+
# middle
|
203 |
+
h = hs[-1]
|
204 |
+
h = self.mid.block_1(h)
|
205 |
+
h = self.mid.attn_1(h)
|
206 |
+
h = self.mid.block_2(h)
|
207 |
+
# end
|
208 |
+
h = self.norm_out(h)
|
209 |
+
h = swish(h)
|
210 |
+
h = self.conv_out(h)
|
211 |
+
return h
|
212 |
+
|
213 |
+
|
214 |
+
class Decoder(nn.Module):
|
215 |
+
def __init__(
|
216 |
+
self,
|
217 |
+
ch: int,
|
218 |
+
out_ch: int,
|
219 |
+
ch_mult: list[int],
|
220 |
+
num_res_blocks: int,
|
221 |
+
in_channels: int,
|
222 |
+
resolution: int,
|
223 |
+
z_channels: int,
|
224 |
+
):
|
225 |
+
super().__init__()
|
226 |
+
self.ch = ch
|
227 |
+
self.num_resolutions = len(ch_mult)
|
228 |
+
self.num_res_blocks = num_res_blocks
|
229 |
+
self.resolution = resolution
|
230 |
+
self.in_channels = in_channels
|
231 |
+
self.ffactor = 2 ** (self.num_resolutions - 1)
|
232 |
+
|
233 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
234 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
235 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
236 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
237 |
+
|
238 |
+
# z to block_in
|
239 |
+
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
240 |
+
|
241 |
+
# middle
|
242 |
+
self.mid = nn.Module()
|
243 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
244 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
245 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
246 |
+
|
247 |
+
# upsampling
|
248 |
+
self.up = nn.ModuleList()
|
249 |
+
for i_level in reversed(range(self.num_resolutions)):
|
250 |
+
block = nn.ModuleList()
|
251 |
+
attn = nn.ModuleList()
|
252 |
+
block_out = ch * ch_mult[i_level]
|
253 |
+
for _ in range(self.num_res_blocks + 1):
|
254 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
255 |
+
block_in = block_out
|
256 |
+
up = nn.Module()
|
257 |
+
up.block = block
|
258 |
+
up.attn = attn
|
259 |
+
if i_level != 0:
|
260 |
+
up.upsample = Upsample(block_in)
|
261 |
+
curr_res = curr_res * 2
|
262 |
+
self.up.insert(0, up) # prepend to get consistent order
|
263 |
+
|
264 |
+
# end
|
265 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
266 |
+
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
267 |
+
|
268 |
+
def forward(self, z: Tensor) -> Tensor:
|
269 |
+
# z to block_in
|
270 |
+
h = self.conv_in(z)
|
271 |
+
|
272 |
+
# middle
|
273 |
+
h = self.mid.block_1(h)
|
274 |
+
h = self.mid.attn_1(h)
|
275 |
+
h = self.mid.block_2(h)
|
276 |
+
|
277 |
+
# upsampling
|
278 |
+
for i_level in reversed(range(self.num_resolutions)):
|
279 |
+
for i_block in range(self.num_res_blocks + 1):
|
280 |
+
h = self.up[i_level].block[i_block](h)
|
281 |
+
if len(self.up[i_level].attn) > 0:
|
282 |
+
h = self.up[i_level].attn[i_block](h)
|
283 |
+
if i_level != 0:
|
284 |
+
h = self.up[i_level].upsample(h)
|
285 |
+
|
286 |
+
# end
|
287 |
+
h = self.norm_out(h)
|
288 |
+
h = swish(h)
|
289 |
+
h = self.conv_out(h)
|
290 |
+
return h
|
291 |
+
|
292 |
+
|
293 |
+
class DiagonalGaussian(nn.Module):
|
294 |
+
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
295 |
+
super().__init__()
|
296 |
+
self.sample = sample
|
297 |
+
self.chunk_dim = chunk_dim
|
298 |
+
|
299 |
+
def forward(self, z: Tensor) -> Tensor:
|
300 |
+
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
301 |
+
if self.sample:
|
302 |
+
std = torch.exp(0.5 * logvar)
|
303 |
+
return mean + std * torch.randn_like(mean)
|
304 |
+
else:
|
305 |
+
return mean
|
306 |
+
|
307 |
+
|
308 |
+
class AutoEncoder(nn.Module):
|
309 |
+
def __init__(self, params: AutoEncoderParams):
|
310 |
+
super().__init__()
|
311 |
+
self.encoder = Encoder(
|
312 |
+
resolution=params.resolution,
|
313 |
+
in_channels=params.in_channels,
|
314 |
+
ch=params.ch,
|
315 |
+
ch_mult=params.ch_mult,
|
316 |
+
num_res_blocks=params.num_res_blocks,
|
317 |
+
z_channels=params.z_channels,
|
318 |
+
)
|
319 |
+
self.decoder = Decoder(
|
320 |
+
resolution=params.resolution,
|
321 |
+
in_channels=params.in_channels,
|
322 |
+
ch=params.ch,
|
323 |
+
out_ch=params.out_ch,
|
324 |
+
ch_mult=params.ch_mult,
|
325 |
+
num_res_blocks=params.num_res_blocks,
|
326 |
+
z_channels=params.z_channels,
|
327 |
+
)
|
328 |
+
self.reg = DiagonalGaussian()
|
329 |
+
|
330 |
+
self.scale_factor = params.scale_factor
|
331 |
+
self.shift_factor = params.shift_factor
|
332 |
+
|
333 |
+
@property
|
334 |
+
def device(self) -> torch.device:
|
335 |
+
return next(self.parameters()).device
|
336 |
+
|
337 |
+
@property
|
338 |
+
def dtype(self) -> torch.dtype:
|
339 |
+
return next(self.parameters()).dtype
|
340 |
+
|
341 |
+
def encode(self, x: Tensor) -> Tensor:
|
342 |
+
z = self.reg(self.encoder(x))
|
343 |
+
z = self.scale_factor * (z - self.shift_factor)
|
344 |
+
return z
|
345 |
+
|
346 |
+
def decode(self, z: Tensor) -> Tensor:
|
347 |
+
z = z / self.scale_factor + self.shift_factor
|
348 |
+
return self.decoder(z)
|
349 |
+
|
350 |
+
def forward(self, x: Tensor) -> Tensor:
|
351 |
+
return self.decode(self.encode(x))
|
352 |
+
|
353 |
+
|
354 |
+
# endregion
|
355 |
+
# region config
|
356 |
+
|
357 |
+
|
358 |
+
@dataclass
|
359 |
+
class ModelSpec:
|
360 |
+
params: FluxParams
|
361 |
+
ae_params: AutoEncoderParams
|
362 |
+
ckpt_path: str | None
|
363 |
+
ae_path: str | None
|
364 |
+
# repo_id: str | None
|
365 |
+
# repo_flow: str | None
|
366 |
+
# repo_ae: str | None
|
367 |
+
|
368 |
+
|
369 |
+
configs = {
|
370 |
+
"dev": ModelSpec(
|
371 |
+
# repo_id="black-forest-labs/FLUX.1-dev",
|
372 |
+
# repo_flow="flux1-dev.sft",
|
373 |
+
# repo_ae="ae.sft",
|
374 |
+
ckpt_path=None, # os.getenv("FLUX_DEV"),
|
375 |
+
params=FluxParams(
|
376 |
+
in_channels=64,
|
377 |
+
vec_in_dim=768,
|
378 |
+
context_in_dim=4096,
|
379 |
+
hidden_size=3072,
|
380 |
+
mlp_ratio=4.0,
|
381 |
+
num_heads=24,
|
382 |
+
depth=19,
|
383 |
+
depth_single_blocks=38,
|
384 |
+
axes_dim=[16, 56, 56],
|
385 |
+
theta=10_000,
|
386 |
+
qkv_bias=True,
|
387 |
+
guidance_embed=True,
|
388 |
+
),
|
389 |
+
ae_path=None, # os.getenv("AE"),
|
390 |
+
ae_params=AutoEncoderParams(
|
391 |
+
resolution=256,
|
392 |
+
in_channels=3,
|
393 |
+
ch=128,
|
394 |
+
out_ch=3,
|
395 |
+
ch_mult=[1, 2, 4, 4],
|
396 |
+
num_res_blocks=2,
|
397 |
+
z_channels=16,
|
398 |
+
scale_factor=0.3611,
|
399 |
+
shift_factor=0.1159,
|
400 |
+
),
|
401 |
+
),
|
402 |
+
"schnell": ModelSpec(
|
403 |
+
# repo_id="black-forest-labs/FLUX.1-schnell",
|
404 |
+
# repo_flow="flux1-schnell.sft",
|
405 |
+
# repo_ae="ae.sft",
|
406 |
+
ckpt_path=None, # os.getenv("FLUX_SCHNELL"),
|
407 |
+
params=FluxParams(
|
408 |
+
in_channels=64,
|
409 |
+
vec_in_dim=768,
|
410 |
+
context_in_dim=4096,
|
411 |
+
hidden_size=3072,
|
412 |
+
mlp_ratio=4.0,
|
413 |
+
num_heads=24,
|
414 |
+
depth=19,
|
415 |
+
depth_single_blocks=38,
|
416 |
+
axes_dim=[16, 56, 56],
|
417 |
+
theta=10_000,
|
418 |
+
qkv_bias=True,
|
419 |
+
guidance_embed=False,
|
420 |
+
),
|
421 |
+
ae_path=None, # os.getenv("AE"),
|
422 |
+
ae_params=AutoEncoderParams(
|
423 |
+
resolution=256,
|
424 |
+
in_channels=3,
|
425 |
+
ch=128,
|
426 |
+
out_ch=3,
|
427 |
+
ch_mult=[1, 2, 4, 4],
|
428 |
+
num_res_blocks=2,
|
429 |
+
z_channels=16,
|
430 |
+
scale_factor=0.3611,
|
431 |
+
shift_factor=0.1159,
|
432 |
+
),
|
433 |
+
),
|
434 |
+
}
|
435 |
+
|
436 |
+
|
437 |
+
# endregion
|
438 |
+
|
439 |
+
# region math
|
440 |
+
|
441 |
+
|
442 |
+
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
|
443 |
+
q, k = apply_rope(q, k, pe)
|
444 |
+
|
445 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
446 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
447 |
+
|
448 |
+
return x
|
449 |
+
|
450 |
+
|
451 |
+
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
452 |
+
assert dim % 2 == 0
|
453 |
+
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
454 |
+
omega = 1.0 / (theta**scale)
|
455 |
+
out = torch.einsum("...n,d->...nd", pos, omega)
|
456 |
+
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
457 |
+
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
458 |
+
return out.float()
|
459 |
+
|
460 |
+
|
461 |
+
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
462 |
+
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
463 |
+
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
464 |
+
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
465 |
+
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
466 |
+
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
467 |
+
|
468 |
+
|
469 |
+
# endregion
|
470 |
+
|
471 |
+
|
472 |
+
# region layers
|
473 |
+
|
474 |
+
|
475 |
+
# for cpu_offload_checkpointing
|
476 |
+
|
477 |
+
|
478 |
+
def to_cuda(x):
|
479 |
+
if isinstance(x, torch.Tensor):
|
480 |
+
return x.cuda()
|
481 |
+
elif isinstance(x, (list, tuple)):
|
482 |
+
return [to_cuda(elem) for elem in x]
|
483 |
+
elif isinstance(x, dict):
|
484 |
+
return {k: to_cuda(v) for k, v in x.items()}
|
485 |
+
else:
|
486 |
+
return x
|
487 |
+
|
488 |
+
|
489 |
+
def to_cpu(x):
|
490 |
+
if isinstance(x, torch.Tensor):
|
491 |
+
return x.cpu()
|
492 |
+
elif isinstance(x, (list, tuple)):
|
493 |
+
return [to_cpu(elem) for elem in x]
|
494 |
+
elif isinstance(x, dict):
|
495 |
+
return {k: to_cpu(v) for k, v in x.items()}
|
496 |
+
else:
|
497 |
+
return x
|
498 |
+
|
499 |
+
|
500 |
+
class EmbedND(nn.Module):
|
501 |
+
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
502 |
+
super().__init__()
|
503 |
+
self.dim = dim
|
504 |
+
self.theta = theta
|
505 |
+
self.axes_dim = axes_dim
|
506 |
+
|
507 |
+
def forward(self, ids: Tensor) -> Tensor:
|
508 |
+
n_axes = ids.shape[-1]
|
509 |
+
emb = torch.cat(
|
510 |
+
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
511 |
+
dim=-3,
|
512 |
+
)
|
513 |
+
|
514 |
+
return emb.unsqueeze(1)
|
515 |
+
|
516 |
+
|
517 |
+
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
518 |
+
"""
|
519 |
+
Create sinusoidal timestep embeddings.
|
520 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
521 |
+
These may be fractional.
|
522 |
+
:param dim: the dimension of the output.
|
523 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
524 |
+
:return: an (N, D) Tensor of positional embeddings.
|
525 |
+
"""
|
526 |
+
t = time_factor * t
|
527 |
+
half = dim // 2
|
528 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
|
529 |
+
|
530 |
+
args = t[:, None].float() * freqs[None]
|
531 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
532 |
+
if dim % 2:
|
533 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
534 |
+
if torch.is_floating_point(t):
|
535 |
+
embedding = embedding.to(t)
|
536 |
+
return embedding
|
537 |
+
|
538 |
+
|
539 |
+
class MLPEmbedder(nn.Module):
|
540 |
+
def __init__(self, in_dim: int, hidden_dim: int):
|
541 |
+
super().__init__()
|
542 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
543 |
+
self.silu = nn.SiLU()
|
544 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
545 |
+
|
546 |
+
self.gradient_checkpointing = False
|
547 |
+
|
548 |
+
def enable_gradient_checkpointing(self):
|
549 |
+
self.gradient_checkpointing = True
|
550 |
+
|
551 |
+
def disable_gradient_checkpointing(self):
|
552 |
+
self.gradient_checkpointing = False
|
553 |
+
|
554 |
+
def _forward(self, x: Tensor) -> Tensor:
|
555 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
556 |
+
|
557 |
+
def forward(self, *args, **kwargs):
|
558 |
+
if self.training and self.gradient_checkpointing:
|
559 |
+
return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
|
560 |
+
else:
|
561 |
+
return self._forward(*args, **kwargs)
|
562 |
+
|
563 |
+
# def forward(self, x):
|
564 |
+
# if self.training and self.gradient_checkpointing:
|
565 |
+
# def create_custom_forward(func):
|
566 |
+
# def custom_forward(*inputs):
|
567 |
+
# return func(*inputs)
|
568 |
+
# return custom_forward
|
569 |
+
# return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, use_reentrant=USE_REENTRANT)
|
570 |
+
# else:
|
571 |
+
# return self._forward(x)
|
572 |
+
|
573 |
+
|
574 |
+
class RMSNorm(torch.nn.Module):
|
575 |
+
def __init__(self, dim: int):
|
576 |
+
super().__init__()
|
577 |
+
self.scale = nn.Parameter(torch.ones(dim))
|
578 |
+
|
579 |
+
def forward(self, x: Tensor):
|
580 |
+
x_dtype = x.dtype
|
581 |
+
x = x.float()
|
582 |
+
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
583 |
+
# return (x * rrms).to(dtype=x_dtype) * self.scale
|
584 |
+
return ((x * rrms) * self.scale.float()).to(dtype=x_dtype)
|
585 |
+
|
586 |
+
|
587 |
+
class QKNorm(torch.nn.Module):
|
588 |
+
def __init__(self, dim: int):
|
589 |
+
super().__init__()
|
590 |
+
self.query_norm = RMSNorm(dim)
|
591 |
+
self.key_norm = RMSNorm(dim)
|
592 |
+
|
593 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
594 |
+
q = self.query_norm(q)
|
595 |
+
k = self.key_norm(k)
|
596 |
+
return q.to(v), k.to(v)
|
597 |
+
|
598 |
+
|
599 |
+
class SelfAttention(nn.Module):
|
600 |
+
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
601 |
+
super().__init__()
|
602 |
+
self.num_heads = num_heads
|
603 |
+
head_dim = dim // num_heads
|
604 |
+
|
605 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
606 |
+
self.norm = QKNorm(head_dim)
|
607 |
+
self.proj = nn.Linear(dim, dim)
|
608 |
+
|
609 |
+
# this is not called from DoubleStreamBlock/SingleStreamBlock because they uses attention function directly
|
610 |
+
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
611 |
+
qkv = self.qkv(x)
|
612 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
613 |
+
q, k = self.norm(q, k, v)
|
614 |
+
x = attention(q, k, v, pe=pe)
|
615 |
+
x = self.proj(x)
|
616 |
+
return x
|
617 |
+
|
618 |
+
|
619 |
+
@dataclass
|
620 |
+
class ModulationOut:
|
621 |
+
shift: Tensor
|
622 |
+
scale: Tensor
|
623 |
+
gate: Tensor
|
624 |
+
|
625 |
+
|
626 |
+
class Modulation(nn.Module):
|
627 |
+
def __init__(self, dim: int, double: bool):
|
628 |
+
super().__init__()
|
629 |
+
self.is_double = double
|
630 |
+
self.multiplier = 6 if double else 3
|
631 |
+
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
632 |
+
|
633 |
+
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
634 |
+
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
635 |
+
|
636 |
+
return (
|
637 |
+
ModulationOut(*out[:3]),
|
638 |
+
ModulationOut(*out[3:]) if self.is_double else None,
|
639 |
+
)
|
640 |
+
|
641 |
+
|
642 |
+
class DoubleStreamBlock(nn.Module):
|
643 |
+
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
|
644 |
+
super().__init__()
|
645 |
+
|
646 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
647 |
+
self.num_heads = num_heads
|
648 |
+
self.hidden_size = hidden_size
|
649 |
+
self.img_mod = Modulation(hidden_size, double=True)
|
650 |
+
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
651 |
+
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
652 |
+
|
653 |
+
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
654 |
+
self.img_mlp = nn.Sequential(
|
655 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
656 |
+
nn.GELU(approximate="tanh"),
|
657 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
658 |
+
)
|
659 |
+
|
660 |
+
self.txt_mod = Modulation(hidden_size, double=True)
|
661 |
+
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
662 |
+
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
663 |
+
|
664 |
+
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
665 |
+
self.txt_mlp = nn.Sequential(
|
666 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
667 |
+
nn.GELU(approximate="tanh"),
|
668 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
669 |
+
)
|
670 |
+
|
671 |
+
self.gradient_checkpointing = False
|
672 |
+
self.cpu_offload_checkpointing = False
|
673 |
+
|
674 |
+
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
|
675 |
+
self.gradient_checkpointing = True
|
676 |
+
self.cpu_offload_checkpointing = cpu_offload
|
677 |
+
|
678 |
+
def disable_gradient_checkpointing(self):
|
679 |
+
self.gradient_checkpointing = False
|
680 |
+
self.cpu_offload_checkpointing = False
|
681 |
+
|
682 |
+
def _forward(
|
683 |
+
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
|
684 |
+
) -> tuple[Tensor, Tensor]:
|
685 |
+
img_mod1, img_mod2 = self.img_mod(vec)
|
686 |
+
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
687 |
+
|
688 |
+
# prepare image for attention
|
689 |
+
img_modulated = self.img_norm1(img)
|
690 |
+
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
691 |
+
img_qkv = self.img_attn.qkv(img_modulated)
|
692 |
+
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
693 |
+
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
694 |
+
|
695 |
+
# prepare txt for attention
|
696 |
+
txt_modulated = self.txt_norm1(txt)
|
697 |
+
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
698 |
+
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
699 |
+
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
700 |
+
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
701 |
+
|
702 |
+
# run actual attention
|
703 |
+
q = torch.cat((txt_q, img_q), dim=2)
|
704 |
+
k = torch.cat((txt_k, img_k), dim=2)
|
705 |
+
v = torch.cat((txt_v, img_v), dim=2)
|
706 |
+
|
707 |
+
# make attention mask if not None
|
708 |
+
attn_mask = None
|
709 |
+
if txt_attention_mask is not None:
|
710 |
+
# F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
|
711 |
+
attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
|
712 |
+
attn_mask = torch.cat(
|
713 |
+
(attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1
|
714 |
+
) # b, seq_len + img_len
|
715 |
+
|
716 |
+
# broadcast attn_mask to all heads
|
717 |
+
attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
|
718 |
+
|
719 |
+
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
|
720 |
+
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
721 |
+
|
722 |
+
# calculate the img blocks
|
723 |
+
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
724 |
+
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
725 |
+
|
726 |
+
# calculate the txt blocks
|
727 |
+
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
728 |
+
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
729 |
+
return img, txt
|
730 |
+
|
731 |
+
def forward(
|
732 |
+
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
|
733 |
+
) -> tuple[Tensor, Tensor]:
|
734 |
+
if self.training and self.gradient_checkpointing:
|
735 |
+
if not self.cpu_offload_checkpointing:
|
736 |
+
return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False)
|
737 |
+
# cpu offload checkpointing
|
738 |
+
|
739 |
+
def create_custom_forward(func):
|
740 |
+
def custom_forward(*inputs):
|
741 |
+
cuda_inputs = to_cuda(inputs)
|
742 |
+
outputs = func(*cuda_inputs)
|
743 |
+
return to_cpu(outputs)
|
744 |
+
|
745 |
+
return custom_forward
|
746 |
+
|
747 |
+
return torch.utils.checkpoint.checkpoint(
|
748 |
+
create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False
|
749 |
+
)
|
750 |
+
|
751 |
+
else:
|
752 |
+
return self._forward(img, txt, vec, pe, txt_attention_mask)
|
753 |
+
|
754 |
+
|
755 |
+
class SingleStreamBlock(nn.Module):
|
756 |
+
"""
|
757 |
+
A DiT block with parallel linear layers as described in
|
758 |
+
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
759 |
+
"""
|
760 |
+
|
761 |
+
def __init__(
|
762 |
+
self,
|
763 |
+
hidden_size: int,
|
764 |
+
num_heads: int,
|
765 |
+
mlp_ratio: float = 4.0,
|
766 |
+
qk_scale: float | None = None,
|
767 |
+
):
|
768 |
+
super().__init__()
|
769 |
+
self.hidden_dim = hidden_size
|
770 |
+
self.num_heads = num_heads
|
771 |
+
head_dim = hidden_size // num_heads
|
772 |
+
self.scale = qk_scale or head_dim**-0.5
|
773 |
+
|
774 |
+
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
775 |
+
# qkv and mlp_in
|
776 |
+
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
777 |
+
# proj and mlp_out
|
778 |
+
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
779 |
+
|
780 |
+
self.norm = QKNorm(head_dim)
|
781 |
+
|
782 |
+
self.hidden_size = hidden_size
|
783 |
+
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
784 |
+
|
785 |
+
self.mlp_act = nn.GELU(approximate="tanh")
|
786 |
+
self.modulation = Modulation(hidden_size, double=False)
|
787 |
+
|
788 |
+
self.gradient_checkpointing = False
|
789 |
+
self.cpu_offload_checkpointing = False
|
790 |
+
|
791 |
+
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
|
792 |
+
self.gradient_checkpointing = True
|
793 |
+
self.cpu_offload_checkpointing = cpu_offload
|
794 |
+
|
795 |
+
def disable_gradient_checkpointing(self):
|
796 |
+
self.gradient_checkpointing = False
|
797 |
+
self.cpu_offload_checkpointing = False
|
798 |
+
|
799 |
+
def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
|
800 |
+
mod, _ = self.modulation(vec)
|
801 |
+
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
802 |
+
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
803 |
+
|
804 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
805 |
+
q, k = self.norm(q, k, v)
|
806 |
+
|
807 |
+
# make attention mask if not None
|
808 |
+
attn_mask = None
|
809 |
+
if txt_attention_mask is not None:
|
810 |
+
# F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
|
811 |
+
attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
|
812 |
+
attn_mask = torch.cat(
|
813 |
+
(
|
814 |
+
attn_mask,
|
815 |
+
torch.ones(
|
816 |
+
attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool
|
817 |
+
),
|
818 |
+
),
|
819 |
+
dim=1,
|
820 |
+
) # b, seq_len + img_len = x_len
|
821 |
+
|
822 |
+
# broadcast attn_mask to all heads
|
823 |
+
attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
|
824 |
+
|
825 |
+
# compute attention
|
826 |
+
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
|
827 |
+
|
828 |
+
# compute activation in mlp stream, cat again and run second linear layer
|
829 |
+
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
830 |
+
return x + mod.gate * output
|
831 |
+
|
832 |
+
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
|
833 |
+
if self.training and self.gradient_checkpointing:
|
834 |
+
if not self.cpu_offload_checkpointing:
|
835 |
+
return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False)
|
836 |
+
|
837 |
+
# cpu offload checkpointing
|
838 |
+
|
839 |
+
def create_custom_forward(func):
|
840 |
+
def custom_forward(*inputs):
|
841 |
+
cuda_inputs = to_cuda(inputs)
|
842 |
+
outputs = func(*cuda_inputs)
|
843 |
+
return to_cpu(outputs)
|
844 |
+
|
845 |
+
return custom_forward
|
846 |
+
|
847 |
+
return torch.utils.checkpoint.checkpoint(
|
848 |
+
create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False
|
849 |
+
)
|
850 |
+
else:
|
851 |
+
return self._forward(x, vec, pe, txt_attention_mask)
|
852 |
+
|
853 |
+
|
854 |
+
class LastLayer(nn.Module):
|
855 |
+
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
856 |
+
super().__init__()
|
857 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
858 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
859 |
+
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
860 |
+
|
861 |
+
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
862 |
+
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
863 |
+
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
864 |
+
x = self.linear(x)
|
865 |
+
return x
|
866 |
+
|
867 |
+
|
868 |
+
# endregion
|
869 |
+
|
870 |
+
|
871 |
+
class Flux(nn.Module):
|
872 |
+
"""
|
873 |
+
Transformer model for flow matching on sequences.
|
874 |
+
"""
|
875 |
+
|
876 |
+
def __init__(self, params: FluxParams):
|
877 |
+
super().__init__()
|
878 |
+
|
879 |
+
self.params = params
|
880 |
+
self.in_channels = params.in_channels
|
881 |
+
self.out_channels = self.in_channels
|
882 |
+
if params.hidden_size % params.num_heads != 0:
|
883 |
+
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
884 |
+
pe_dim = params.hidden_size // params.num_heads
|
885 |
+
if sum(params.axes_dim) != pe_dim:
|
886 |
+
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
887 |
+
self.hidden_size = params.hidden_size
|
888 |
+
self.num_heads = params.num_heads
|
889 |
+
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
890 |
+
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
891 |
+
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
892 |
+
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
893 |
+
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
894 |
+
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
895 |
+
|
896 |
+
self.double_blocks = nn.ModuleList(
|
897 |
+
[
|
898 |
+
DoubleStreamBlock(
|
899 |
+
self.hidden_size,
|
900 |
+
self.num_heads,
|
901 |
+
mlp_ratio=params.mlp_ratio,
|
902 |
+
qkv_bias=params.qkv_bias,
|
903 |
+
)
|
904 |
+
for _ in range(params.depth)
|
905 |
+
]
|
906 |
+
)
|
907 |
+
|
908 |
+
self.single_blocks = nn.ModuleList(
|
909 |
+
[
|
910 |
+
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
911 |
+
for _ in range(params.depth_single_blocks)
|
912 |
+
]
|
913 |
+
)
|
914 |
+
|
915 |
+
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
916 |
+
|
917 |
+
self.gradient_checkpointing = False
|
918 |
+
self.cpu_offload_checkpointing = False
|
919 |
+
self.blocks_to_swap = None
|
920 |
+
|
921 |
+
self.offloader_double = None
|
922 |
+
self.offloader_single = None
|
923 |
+
self.num_double_blocks = len(self.double_blocks)
|
924 |
+
self.num_single_blocks = len(self.single_blocks)
|
925 |
+
|
926 |
+
@property
|
927 |
+
def device(self):
|
928 |
+
return next(self.parameters()).device
|
929 |
+
|
930 |
+
@property
|
931 |
+
def dtype(self):
|
932 |
+
return next(self.parameters()).dtype
|
933 |
+
|
934 |
+
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
|
935 |
+
self.gradient_checkpointing = True
|
936 |
+
self.cpu_offload_checkpointing = cpu_offload
|
937 |
+
|
938 |
+
self.time_in.enable_gradient_checkpointing()
|
939 |
+
self.vector_in.enable_gradient_checkpointing()
|
940 |
+
if self.guidance_in.__class__ != nn.Identity:
|
941 |
+
self.guidance_in.enable_gradient_checkpointing()
|
942 |
+
|
943 |
+
for block in self.double_blocks + self.single_blocks:
|
944 |
+
block.enable_gradient_checkpointing(cpu_offload=cpu_offload)
|
945 |
+
|
946 |
+
print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}")
|
947 |
+
|
948 |
+
def disable_gradient_checkpointing(self):
|
949 |
+
self.gradient_checkpointing = False
|
950 |
+
self.cpu_offload_checkpointing = False
|
951 |
+
|
952 |
+
self.time_in.disable_gradient_checkpointing()
|
953 |
+
self.vector_in.disable_gradient_checkpointing()
|
954 |
+
if self.guidance_in.__class__ != nn.Identity:
|
955 |
+
self.guidance_in.disable_gradient_checkpointing()
|
956 |
+
|
957 |
+
for block in self.double_blocks + self.single_blocks:
|
958 |
+
block.disable_gradient_checkpointing()
|
959 |
+
|
960 |
+
print("FLUX: Gradient checkpointing disabled.")
|
961 |
+
|
962 |
+
def enable_block_swap(self, num_blocks: int, device: torch.device):
|
963 |
+
self.blocks_to_swap = num_blocks
|
964 |
+
double_blocks_to_swap = num_blocks // 2
|
965 |
+
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
|
966 |
+
|
967 |
+
assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, (
|
968 |
+
f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. "
|
969 |
+
f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
|
970 |
+
)
|
971 |
+
|
972 |
+
self.offloader_double = ModelOffloader(
|
973 |
+
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
|
974 |
+
)
|
975 |
+
self.offloader_single = ModelOffloader(
|
976 |
+
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
|
977 |
+
)
|
978 |
+
print(
|
979 |
+
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
|
980 |
+
)
|
981 |
+
|
982 |
+
def move_to_device_except_swap_blocks(self, device: torch.device):
|
983 |
+
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
|
984 |
+
if self.blocks_to_swap:
|
985 |
+
save_double_blocks = self.double_blocks
|
986 |
+
save_single_blocks = self.single_blocks
|
987 |
+
self.double_blocks = None
|
988 |
+
self.single_blocks = None
|
989 |
+
|
990 |
+
self.to(device)
|
991 |
+
|
992 |
+
if self.blocks_to_swap:
|
993 |
+
self.double_blocks = save_double_blocks
|
994 |
+
self.single_blocks = save_single_blocks
|
995 |
+
|
996 |
+
def prepare_block_swap_before_forward(self):
|
997 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
998 |
+
return
|
999 |
+
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
|
1000 |
+
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
|
1001 |
+
|
1002 |
+
def forward(
|
1003 |
+
self,
|
1004 |
+
img: Tensor,
|
1005 |
+
img_ids: Tensor,
|
1006 |
+
txt: Tensor,
|
1007 |
+
txt_ids: Tensor,
|
1008 |
+
timesteps: Tensor,
|
1009 |
+
y: Tensor,
|
1010 |
+
guidance: Tensor | None = None,
|
1011 |
+
txt_attention_mask: Tensor | None = None,
|
1012 |
+
) -> Tensor:
|
1013 |
+
if img.ndim != 3 or txt.ndim != 3:
|
1014 |
+
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
1015 |
+
|
1016 |
+
# running on sequences img
|
1017 |
+
img = self.img_in(img)
|
1018 |
+
vec = self.time_in(timestep_embedding(timesteps, 256))
|
1019 |
+
if self.params.guidance_embed:
|
1020 |
+
if guidance is None:
|
1021 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
1022 |
+
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
1023 |
+
vec = vec + self.vector_in(y)
|
1024 |
+
txt = self.txt_in(txt)
|
1025 |
+
|
1026 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
1027 |
+
pe = self.pe_embedder(ids)
|
1028 |
+
|
1029 |
+
if not self.blocks_to_swap:
|
1030 |
+
for block in self.double_blocks:
|
1031 |
+
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
1032 |
+
img = torch.cat((txt, img), 1)
|
1033 |
+
for block in self.single_blocks:
|
1034 |
+
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
1035 |
+
else:
|
1036 |
+
for block_idx, block in enumerate(self.double_blocks):
|
1037 |
+
self.offloader_double.wait_for_block(block_idx)
|
1038 |
+
|
1039 |
+
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
1040 |
+
|
1041 |
+
self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
|
1042 |
+
|
1043 |
+
img = torch.cat((txt, img), 1)
|
1044 |
+
|
1045 |
+
for block_idx, block in enumerate(self.single_blocks):
|
1046 |
+
self.offloader_single.wait_for_block(block_idx)
|
1047 |
+
|
1048 |
+
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
1049 |
+
|
1050 |
+
self.offloader_single.submit_move_blocks(self.single_blocks, block_idx)
|
1051 |
+
|
1052 |
+
img = img[:, txt.shape[1] :, ...]
|
1053 |
+
|
1054 |
+
if self.training and self.cpu_offload_checkpointing:
|
1055 |
+
img = img.to(self.device)
|
1056 |
+
vec = vec.to(self.device)
|
1057 |
+
|
1058 |
+
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
1059 |
+
|
1060 |
+
return img
|
library/flux_train_utils.py
ADDED
@@ -0,0 +1,585 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
import toml
|
6 |
+
import json
|
7 |
+
import time
|
8 |
+
from typing import Callable, Dict, List, Optional, Tuple, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from accelerate import Accelerator, PartialState
|
12 |
+
from transformers import CLIPTextModel
|
13 |
+
from tqdm import tqdm
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
+
from safetensors.torch import save_file
|
17 |
+
from . import flux_models, flux_utils, strategy_base, train_util
|
18 |
+
from .device_utils import init_ipex, clean_memory_on_device
|
19 |
+
|
20 |
+
init_ipex()
|
21 |
+
|
22 |
+
from .utils import setup_logging, mem_eff_save_file
|
23 |
+
|
24 |
+
setup_logging()
|
25 |
+
import logging
|
26 |
+
|
27 |
+
logger = logging.getLogger(__name__)
|
28 |
+
# from comfy.utils import ProgressBar
|
29 |
+
|
30 |
+
def sample_images(
|
31 |
+
accelerator: Accelerator,
|
32 |
+
args: argparse.Namespace,
|
33 |
+
epoch,
|
34 |
+
steps,
|
35 |
+
flux,
|
36 |
+
ae,
|
37 |
+
text_encoders,
|
38 |
+
sample_prompts_te_outputs,
|
39 |
+
validation_settings=None,
|
40 |
+
prompt_replacement=None,
|
41 |
+
):
|
42 |
+
|
43 |
+
logger.info("")
|
44 |
+
logger.info(f"generating sample images at step: {steps}")
|
45 |
+
|
46 |
+
#distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
|
47 |
+
|
48 |
+
# unwrap unet and text_encoder(s)
|
49 |
+
flux = accelerator.unwrap_model(flux)
|
50 |
+
if text_encoders is not None:
|
51 |
+
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
|
52 |
+
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
|
53 |
+
|
54 |
+
prompts = []
|
55 |
+
for line in args.sample_prompts:
|
56 |
+
line = line.strip()
|
57 |
+
if len(line) > 0 and line[0] != "#":
|
58 |
+
prompts.append(line)
|
59 |
+
|
60 |
+
# preprocess prompts
|
61 |
+
for i in range(len(prompts)):
|
62 |
+
prompt_dict = prompts[i]
|
63 |
+
if isinstance(prompt_dict, str):
|
64 |
+
from .train_util import line_to_prompt_dict
|
65 |
+
|
66 |
+
prompt_dict = line_to_prompt_dict(prompt_dict)
|
67 |
+
prompts[i] = prompt_dict
|
68 |
+
assert isinstance(prompt_dict, dict)
|
69 |
+
|
70 |
+
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
|
71 |
+
prompt_dict["enum"] = i
|
72 |
+
prompt_dict.pop("subset", None)
|
73 |
+
|
74 |
+
save_dir = args.output_dir + "/sample"
|
75 |
+
os.makedirs(save_dir, exist_ok=True)
|
76 |
+
|
77 |
+
# save random state to restore later
|
78 |
+
rng_state = torch.get_rng_state()
|
79 |
+
cuda_rng_state = None
|
80 |
+
try:
|
81 |
+
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
82 |
+
except Exception:
|
83 |
+
pass
|
84 |
+
|
85 |
+
with torch.no_grad(), accelerator.autocast():
|
86 |
+
image_tensor_list = []
|
87 |
+
for prompt_dict in prompts:
|
88 |
+
image_tensor = sample_image_inference(
|
89 |
+
accelerator,
|
90 |
+
args,
|
91 |
+
flux,
|
92 |
+
text_encoders,
|
93 |
+
ae,
|
94 |
+
save_dir,
|
95 |
+
prompt_dict,
|
96 |
+
epoch,
|
97 |
+
steps,
|
98 |
+
sample_prompts_te_outputs,
|
99 |
+
prompt_replacement,
|
100 |
+
validation_settings
|
101 |
+
)
|
102 |
+
image_tensor_list.append(image_tensor)
|
103 |
+
|
104 |
+
torch.set_rng_state(rng_state)
|
105 |
+
if cuda_rng_state is not None:
|
106 |
+
torch.cuda.set_rng_state(cuda_rng_state)
|
107 |
+
|
108 |
+
clean_memory_on_device(accelerator.device)
|
109 |
+
return torch.cat(image_tensor_list, dim=0)
|
110 |
+
|
111 |
+
|
112 |
+
def sample_image_inference(
|
113 |
+
accelerator: Accelerator,
|
114 |
+
args: argparse.Namespace,
|
115 |
+
flux: flux_models.Flux,
|
116 |
+
text_encoders: Optional[List[CLIPTextModel]],
|
117 |
+
ae: flux_models.AutoEncoder,
|
118 |
+
save_dir,
|
119 |
+
prompt_dict,
|
120 |
+
epoch,
|
121 |
+
steps,
|
122 |
+
sample_prompts_te_outputs,
|
123 |
+
prompt_replacement,
|
124 |
+
validation_settings=None
|
125 |
+
):
|
126 |
+
assert isinstance(prompt_dict, dict)
|
127 |
+
# negative_prompt = prompt_dict.get("negative_prompt")
|
128 |
+
if validation_settings is not None:
|
129 |
+
sample_steps = validation_settings["steps"]
|
130 |
+
width = validation_settings["width"]
|
131 |
+
height = validation_settings["height"]
|
132 |
+
scale = validation_settings["guidance_scale"]
|
133 |
+
seed = validation_settings["seed"]
|
134 |
+
base_shift = validation_settings["base_shift"]
|
135 |
+
max_shift = validation_settings["max_shift"]
|
136 |
+
shift = validation_settings["shift"]
|
137 |
+
else:
|
138 |
+
sample_steps = prompt_dict.get("sample_steps", 20)
|
139 |
+
width = prompt_dict.get("width", 512)
|
140 |
+
height = prompt_dict.get("height", 512)
|
141 |
+
scale = prompt_dict.get("scale", 3.5)
|
142 |
+
seed = prompt_dict.get("seed")
|
143 |
+
base_shift = 0.5
|
144 |
+
max_shift = 1.15
|
145 |
+
shift = True
|
146 |
+
# controlnet_image = prompt_dict.get("controlnet_image")
|
147 |
+
prompt: str = prompt_dict.get("prompt", "")
|
148 |
+
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
149 |
+
|
150 |
+
if prompt_replacement is not None:
|
151 |
+
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
152 |
+
# if negative_prompt is not None:
|
153 |
+
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
154 |
+
|
155 |
+
if seed is not None:
|
156 |
+
torch.manual_seed(seed)
|
157 |
+
torch.cuda.manual_seed(seed)
|
158 |
+
else:
|
159 |
+
# True random sample image generation
|
160 |
+
torch.seed()
|
161 |
+
torch.cuda.seed()
|
162 |
+
|
163 |
+
# if negative_prompt is None:
|
164 |
+
# negative_prompt = ""
|
165 |
+
|
166 |
+
height = max(64, height - height % 16) # round to divisible by 16
|
167 |
+
width = max(64, width - width % 16) # round to divisible by 16
|
168 |
+
logger.info(f"prompt: {prompt}")
|
169 |
+
# logger.info(f"negative_prompt: {negative_prompt}")
|
170 |
+
logger.info(f"height: {height}")
|
171 |
+
logger.info(f"width: {width}")
|
172 |
+
logger.info(f"sample_steps: {sample_steps}")
|
173 |
+
logger.info(f"scale: {scale}")
|
174 |
+
# logger.info(f"sample_sampler: {sampler_name}")
|
175 |
+
if seed is not None:
|
176 |
+
logger.info(f"seed: {seed}")
|
177 |
+
|
178 |
+
# encode prompts
|
179 |
+
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
180 |
+
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
181 |
+
|
182 |
+
text_encoder_conds = []
|
183 |
+
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
|
184 |
+
text_encoder_conds = sample_prompts_te_outputs[prompt]
|
185 |
+
print(f"Using cached text encoder outputs for prompt: {prompt}")
|
186 |
+
if text_encoders is not None:
|
187 |
+
print(f"Encoding prompt: {prompt}")
|
188 |
+
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
189 |
+
# strategy has apply_t5_attn_mask option
|
190 |
+
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
191 |
+
|
192 |
+
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
193 |
+
if len(text_encoder_conds) == 0:
|
194 |
+
text_encoder_conds = encoded_text_encoder_conds
|
195 |
+
else:
|
196 |
+
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
197 |
+
for i in range(len(encoded_text_encoder_conds)):
|
198 |
+
if encoded_text_encoder_conds[i] is not None:
|
199 |
+
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
200 |
+
|
201 |
+
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
202 |
+
|
203 |
+
# sample image
|
204 |
+
weight_dtype = ae.dtype # TOFO give dtype as argument
|
205 |
+
packed_latent_height = height // 16
|
206 |
+
packed_latent_width = width // 16
|
207 |
+
noise = torch.randn(
|
208 |
+
1,
|
209 |
+
packed_latent_height * packed_latent_width,
|
210 |
+
16 * 2 * 2,
|
211 |
+
device=accelerator.device,
|
212 |
+
dtype=weight_dtype,
|
213 |
+
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
|
214 |
+
)
|
215 |
+
timesteps = get_schedule(sample_steps, noise.shape[1], base_shift=base_shift, max_shift=max_shift, shift=shift) # FLUX.1 dev -> shift=True
|
216 |
+
#print("TIMESTEPS: ", timesteps)
|
217 |
+
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
|
218 |
+
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
|
219 |
+
|
220 |
+
with accelerator.autocast(), torch.no_grad():
|
221 |
+
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask)
|
222 |
+
|
223 |
+
x = x.float()
|
224 |
+
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
225 |
+
|
226 |
+
# latent to image
|
227 |
+
clean_memory_on_device(accelerator.device)
|
228 |
+
org_vae_device = ae.device # will be on cpu
|
229 |
+
ae.to(accelerator.device) # distributed_state.device is same as accelerator.device
|
230 |
+
with accelerator.autocast(), torch.no_grad():
|
231 |
+
x = ae.decode(x)
|
232 |
+
ae.to(org_vae_device)
|
233 |
+
clean_memory_on_device(accelerator.device)
|
234 |
+
|
235 |
+
x = x.clamp(-1, 1)
|
236 |
+
x = x.permute(0, 2, 3, 1)
|
237 |
+
image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
238 |
+
|
239 |
+
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
|
240 |
+
# but adding 'enum' to the filename should be enough
|
241 |
+
|
242 |
+
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
243 |
+
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
244 |
+
seed_suffix = "" if seed is None else f"_{seed}"
|
245 |
+
i: int = prompt_dict["enum"]
|
246 |
+
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
|
247 |
+
image.save(os.path.join(save_dir, img_filename))
|
248 |
+
return x
|
249 |
+
|
250 |
+
# wandb有効時のみログを送信
|
251 |
+
# try:
|
252 |
+
# wandb_tracker = accelerator.get_tracker("wandb")
|
253 |
+
# try:
|
254 |
+
# import wandb
|
255 |
+
# except ImportError: # 事前に一度確認するのでここはエラー出ないはず
|
256 |
+
# raise ImportError("No wandb / wandb がインストールされていないようです")
|
257 |
+
|
258 |
+
# wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
|
259 |
+
# except: # wandb 無効時
|
260 |
+
# pass
|
261 |
+
|
262 |
+
|
263 |
+
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
264 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
265 |
+
|
266 |
+
|
267 |
+
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
268 |
+
m = (y2 - y1) / (x2 - x1)
|
269 |
+
b = y1 - m * x1
|
270 |
+
return lambda x: m * x + b
|
271 |
+
|
272 |
+
|
273 |
+
def get_schedule(
|
274 |
+
num_steps: int,
|
275 |
+
image_seq_len: int,
|
276 |
+
base_shift: float = 0.5,
|
277 |
+
max_shift: float = 1.15,
|
278 |
+
shift: bool = True,
|
279 |
+
) -> list[float]:
|
280 |
+
# extra step for zero
|
281 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
282 |
+
|
283 |
+
# shifting the schedule to favor high timesteps for higher signal images
|
284 |
+
if shift:
|
285 |
+
# eastimate mu based on linear estimation between two points
|
286 |
+
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
287 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
288 |
+
|
289 |
+
return timesteps.tolist()
|
290 |
+
|
291 |
+
|
292 |
+
def denoise(
|
293 |
+
model: flux_models.Flux,
|
294 |
+
img: torch.Tensor,
|
295 |
+
img_ids: torch.Tensor,
|
296 |
+
txt: torch.Tensor,
|
297 |
+
txt_ids: torch.Tensor,
|
298 |
+
vec: torch.Tensor,
|
299 |
+
timesteps: list[float],
|
300 |
+
guidance: float = 4.0,
|
301 |
+
t5_attn_mask: Optional[torch.Tensor] = None,
|
302 |
+
):
|
303 |
+
# this is ignored for schnell
|
304 |
+
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
305 |
+
# comfy_pbar = ProgressBar(total=len(timesteps))
|
306 |
+
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
307 |
+
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
308 |
+
model.prepare_block_swap_before_forward()
|
309 |
+
pred = model(
|
310 |
+
img=img,
|
311 |
+
img_ids=img_ids,
|
312 |
+
txt=txt,
|
313 |
+
txt_ids=txt_ids,
|
314 |
+
y=vec,
|
315 |
+
timesteps=t_vec,
|
316 |
+
guidance=guidance_vec,
|
317 |
+
txt_attention_mask=t5_attn_mask,
|
318 |
+
)
|
319 |
+
|
320 |
+
img = img + (t_prev - t_curr) * pred
|
321 |
+
# comfy_pbar.update(1)
|
322 |
+
model.prepare_block_swap_before_forward()
|
323 |
+
return img
|
324 |
+
|
325 |
+
# endregion
|
326 |
+
|
327 |
+
|
328 |
+
# region train
|
329 |
+
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
|
330 |
+
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
|
331 |
+
schedule_timesteps = noise_scheduler.timesteps.to(device)
|
332 |
+
timesteps = timesteps.to(device)
|
333 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
334 |
+
|
335 |
+
sigma = sigmas[step_indices].flatten()
|
336 |
+
while len(sigma.shape) < n_dim:
|
337 |
+
sigma = sigma.unsqueeze(-1)
|
338 |
+
return sigma
|
339 |
+
|
340 |
+
|
341 |
+
def compute_density_for_timestep_sampling(
|
342 |
+
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
|
343 |
+
):
|
344 |
+
"""Compute the density for sampling the timesteps when doing SD3 training.
|
345 |
+
|
346 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
347 |
+
|
348 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
349 |
+
"""
|
350 |
+
if weighting_scheme == "logit_normal":
|
351 |
+
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
352 |
+
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
|
353 |
+
u = torch.nn.functional.sigmoid(u)
|
354 |
+
elif weighting_scheme == "mode":
|
355 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
356 |
+
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
357 |
+
else:
|
358 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
359 |
+
return u
|
360 |
+
|
361 |
+
|
362 |
+
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
363 |
+
"""Computes loss weighting scheme for SD3 training.
|
364 |
+
|
365 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
366 |
+
|
367 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
368 |
+
"""
|
369 |
+
if weighting_scheme == "sigma_sqrt":
|
370 |
+
weighting = (sigmas**-2.0).float()
|
371 |
+
elif weighting_scheme == "cosmap":
|
372 |
+
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
373 |
+
weighting = 2 / (math.pi * bot)
|
374 |
+
else:
|
375 |
+
weighting = torch.ones_like(sigmas)
|
376 |
+
return weighting
|
377 |
+
|
378 |
+
|
379 |
+
def get_noisy_model_input_and_timesteps(
|
380 |
+
args, noise_scheduler, latents, noise, device, dtype
|
381 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
382 |
+
bsz, _, H, W = latents.shape
|
383 |
+
sigmas = None
|
384 |
+
|
385 |
+
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
386 |
+
# Simple random t-based noise sampling
|
387 |
+
if args.timestep_sampling == "sigmoid":
|
388 |
+
# https://github.com/XLabs-AI/x-flux/tree/main
|
389 |
+
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
390 |
+
else:
|
391 |
+
t = torch.rand((bsz,), device=device)
|
392 |
+
|
393 |
+
timesteps = t * 1000.0
|
394 |
+
t = t.view(-1, 1, 1, 1)
|
395 |
+
noisy_model_input = (1 - t) * latents + t * noise
|
396 |
+
elif args.timestep_sampling == "shift":
|
397 |
+
shift = args.discrete_flow_shift
|
398 |
+
logits_norm = torch.randn(bsz, device=device)
|
399 |
+
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
400 |
+
timesteps = logits_norm.sigmoid()
|
401 |
+
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
|
402 |
+
|
403 |
+
t = timesteps.view(-1, 1, 1, 1)
|
404 |
+
timesteps = timesteps * 1000.0
|
405 |
+
noisy_model_input = (1 - t) * latents + t * noise
|
406 |
+
elif args.timestep_sampling == "flux_shift":
|
407 |
+
logits_norm = torch.randn(bsz, device=device)
|
408 |
+
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
409 |
+
timesteps = logits_norm.sigmoid()
|
410 |
+
mu=get_lin_function(y1=0.5, y2=1.15)((H//2) * (W//2))
|
411 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
412 |
+
|
413 |
+
t = timesteps.view(-1, 1, 1, 1)
|
414 |
+
timesteps = timesteps * 1000.0
|
415 |
+
noisy_model_input = (1 - t) * latents + t * noise
|
416 |
+
else:
|
417 |
+
# Sample a random timestep for each image
|
418 |
+
# for weighting schemes where we sample timesteps non-uniformly
|
419 |
+
u = compute_density_for_timestep_sampling(
|
420 |
+
weighting_scheme=args.weighting_scheme,
|
421 |
+
batch_size=bsz,
|
422 |
+
logit_mean=args.logit_mean,
|
423 |
+
logit_std=args.logit_std,
|
424 |
+
mode_scale=args.mode_scale,
|
425 |
+
)
|
426 |
+
indices = (u * noise_scheduler.config.num_train_timesteps).long()
|
427 |
+
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
428 |
+
|
429 |
+
# Add noise according to flow matching.
|
430 |
+
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
431 |
+
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
432 |
+
|
433 |
+
return noisy_model_input, timesteps, sigmas
|
434 |
+
|
435 |
+
|
436 |
+
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
|
437 |
+
weighting = None
|
438 |
+
if args.model_prediction_type == "raw":
|
439 |
+
pass
|
440 |
+
elif args.model_prediction_type == "additive":
|
441 |
+
# add the model_pred to the noisy_model_input
|
442 |
+
model_pred = model_pred + noisy_model_input
|
443 |
+
elif args.model_prediction_type == "sigma_scaled":
|
444 |
+
# apply sigma scaling
|
445 |
+
model_pred = model_pred * (-sigmas) + noisy_model_input
|
446 |
+
|
447 |
+
# these weighting schemes use a uniform timestep sampling
|
448 |
+
# and instead post-weight the loss
|
449 |
+
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
450 |
+
|
451 |
+
return model_pred, weighting
|
452 |
+
|
453 |
+
|
454 |
+
def save_models(
|
455 |
+
ckpt_path: str,
|
456 |
+
flux: flux_models.Flux,
|
457 |
+
sai_metadata: Optional[dict],
|
458 |
+
save_dtype: Optional[torch.dtype] = None,
|
459 |
+
use_mem_eff_save: bool = False,
|
460 |
+
):
|
461 |
+
state_dict = {}
|
462 |
+
|
463 |
+
def update_sd(prefix, sd):
|
464 |
+
for k, v in sd.items():
|
465 |
+
key = prefix + k
|
466 |
+
if save_dtype is not None and v.dtype != save_dtype:
|
467 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
468 |
+
state_dict[key] = v
|
469 |
+
|
470 |
+
update_sd("", flux.state_dict())
|
471 |
+
|
472 |
+
if not use_mem_eff_save:
|
473 |
+
save_file(state_dict, ckpt_path, metadata=sai_metadata)
|
474 |
+
else:
|
475 |
+
mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
|
476 |
+
|
477 |
+
|
478 |
+
def save_flux_model_on_train_end(
|
479 |
+
args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux
|
480 |
+
):
|
481 |
+
def sd_saver(ckpt_file, epoch_no, global_step):
|
482 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
|
483 |
+
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
|
484 |
+
|
485 |
+
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
|
486 |
+
|
487 |
+
|
488 |
+
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
|
489 |
+
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
|
490 |
+
def save_flux_model_on_epoch_end_or_stepwise(
|
491 |
+
args: argparse.Namespace,
|
492 |
+
on_epoch_end: bool,
|
493 |
+
accelerator,
|
494 |
+
save_dtype: torch.dtype,
|
495 |
+
epoch: int,
|
496 |
+
num_train_epochs: int,
|
497 |
+
global_step: int,
|
498 |
+
flux: flux_models.Flux,
|
499 |
+
):
|
500 |
+
def sd_saver(ckpt_file, epoch_no, global_step):
|
501 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
|
502 |
+
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
|
503 |
+
|
504 |
+
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
|
505 |
+
args,
|
506 |
+
on_epoch_end,
|
507 |
+
accelerator,
|
508 |
+
True,
|
509 |
+
True,
|
510 |
+
epoch,
|
511 |
+
num_train_epochs,
|
512 |
+
global_step,
|
513 |
+
sd_saver,
|
514 |
+
None,
|
515 |
+
)
|
516 |
+
|
517 |
+
|
518 |
+
# endregion
|
519 |
+
|
520 |
+
|
521 |
+
def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
522 |
+
parser.add_argument(
|
523 |
+
"--clip_l",
|
524 |
+
type=str,
|
525 |
+
help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)、float16が前提",
|
526 |
+
)
|
527 |
+
parser.add_argument(
|
528 |
+
"--t5xxl",
|
529 |
+
type=str,
|
530 |
+
help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提",
|
531 |
+
)
|
532 |
+
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
|
533 |
+
parser.add_argument(
|
534 |
+
"--t5xxl_max_token_length",
|
535 |
+
type=int,
|
536 |
+
default=None,
|
537 |
+
help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev"
|
538 |
+
" / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
|
539 |
+
)
|
540 |
+
parser.add_argument(
|
541 |
+
"--apply_t5_attn_mask",
|
542 |
+
action="store_true",
|
543 |
+
help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
|
544 |
+
)
|
545 |
+
|
546 |
+
parser.add_argument(
|
547 |
+
"--guidance_scale",
|
548 |
+
type=float,
|
549 |
+
default=3.5,
|
550 |
+
help="the FLUX.1 dev variant is a guidance distilled model",
|
551 |
+
)
|
552 |
+
|
553 |
+
parser.add_argument(
|
554 |
+
"--timestep_sampling",
|
555 |
+
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
|
556 |
+
default="sigma",
|
557 |
+
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
|
558 |
+
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
|
559 |
+
)
|
560 |
+
parser.add_argument(
|
561 |
+
"--sigmoid_scale",
|
562 |
+
type=float,
|
563 |
+
default=1.0,
|
564 |
+
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
|
565 |
+
)
|
566 |
+
parser.add_argument(
|
567 |
+
"--model_prediction_type",
|
568 |
+
choices=["raw", "additive", "sigma_scaled"],
|
569 |
+
default="sigma_scaled",
|
570 |
+
help="How to interpret and process the model prediction: "
|
571 |
+
"raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
|
572 |
+
" / モデル予測の解釈と処理方法:"
|
573 |
+
"raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。",
|
574 |
+
)
|
575 |
+
parser.add_argument(
|
576 |
+
"--discrete_flow_shift",
|
577 |
+
type=float,
|
578 |
+
default=3.0,
|
579 |
+
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
|
580 |
+
)
|
581 |
+
parser.add_argument(
|
582 |
+
"--bypass_flux_guidance"
|
583 |
+
, action="store_true"
|
584 |
+
, help="bypass flux guidance module for Flex.1-Alpha Training"
|
585 |
+
)
|