|
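# ComfyUI custom nodes for training Flux models (LoRA and full fine-tuning),
# wrapping the kohya-ss sd-scripts style trainers bundled with this repo
# (FluxNetworkTrainer / FluxTrainer). Node outputs are plain Python objects
# passed along the graph: model paths, dataset config JSON, an argparse
# namespace, and a trainer handle holding the training-loop closure.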
import os
import torch
from torchvision import transforms

import folder_paths
import comfy.model_management as mm
import comfy.utils
import toml
import json
import time
import shutil
import shlex

from pathlib import Path
script_directory = os.path.dirname(os.path.abspath(__file__))

from .flux_train_network_comfy import FluxNetworkTrainer
from .library import flux_train_utils as flux_train_utils
from .flux_train_comfy import FluxTrainer
from .flux_train_comfy import setup_parser as train_setup_parser
from .library.device_utils import init_ipex
init_ipex()

from .library import train_util
from .train_network import setup_parser as train_network_setup_parser
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import io
from PIL import Image

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
|
class FluxTrainModelSelect:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "transformer": (folder_paths.get_filename_list("unet"), ),
            "vae": (folder_paths.get_filename_list("vae"), ),
            "clip_l": (folder_paths.get_filename_list("clip"), ),
            "t5": (folder_paths.get_filename_list("clip"), ),
            },
            "optional": {
                "lora_path": ("STRING", {"multiline": True, "forceInput": True, "default": "", "tooltip": "pre-trained LoRA path to load (network_weights)"}),
            }
        }

    RETURN_TYPES = ("TRAIN_FLUX_MODELS",)
    RETURN_NAMES = ("flux_models",)
    FUNCTION = "loadmodel"
    CATEGORY = "FluxTrainer"

    def loadmodel(self, transformer, vae, clip_l, t5, lora_path=""):
        transformer_path = folder_paths.get_full_path("unet", transformer)
        vae_path = folder_paths.get_full_path("vae", vae)
        clip_path = folder_paths.get_full_path("clip", clip_l)
        t5_path = folder_paths.get_full_path("clip", t5)

        flux_models = {
            "transformer": transformer_path,
            "vae": vae_path,
            "clip_l": clip_path,
            "t5": t5_path,
            "lora_path": lora_path
        }

        return (flux_models,)
|
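# Dataset configs travel through the graph as JSON text: this node creates the
# "general" section plus an empty "datasets" list, each TrainDatasetAdd node
# appends one dataset to it, and the Init* training nodes finally convert the
# accumulated JSON to the TOML format the kohya dataset loader expects.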
class TrainDatasetGeneralConfig:
    queue_counter = 0

    @classmethod
    def IS_CHANGED(s, reset_on_queue=False, **kwargs):
        if reset_on_queue:
            s.queue_counter += 1
        print(f"queue_counter: {s.queue_counter}")
        return s.queue_counter

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "color_aug": ("BOOLEAN", {"default": False, "tooltip": "enable weak color augmentation"}),
            "flip_aug": ("BOOLEAN", {"default": False, "tooltip": "enable horizontal flip augmentation"}),
            "shuffle_caption": ("BOOLEAN", {"default": False, "tooltip": "shuffle caption"}),
            "caption_dropout_rate": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "tag dropout rate"}),
            "alpha_mask": ("BOOLEAN", {"default": False, "tooltip": "use alpha channel as mask for training"}),
            },
            "optional": {
                "reset_on_queue": ("BOOLEAN", {"default": False, "tooltip": "Force refresh of everything for cleaner queueing"}),
                "caption_extension": ("STRING", {"default": ".txt", "tooltip": "extension for caption files"}),
            }
        }

    RETURN_TYPES = ("JSON",)
    RETURN_NAMES = ("dataset_general",)
    FUNCTION = "create_config"
    CATEGORY = "FluxTrainer"

    def create_config(self, shuffle_caption, caption_dropout_rate, color_aug, flip_aug, alpha_mask, reset_on_queue=False, caption_extension=".txt"):
        dataset = {
            "general": {
                "shuffle_caption": shuffle_caption,
                "caption_extension": caption_extension,
                "keep_tokens_separator": "|||",
                "caption_dropout_rate": caption_dropout_rate,
                "color_aug": color_aug,
                "flip_aug": flip_aug,
            },
            "datasets": []
        }
        dataset_json = json.dumps(dataset, indent=2)

        dataset_config = {
            "datasets": dataset_json,
            "alpha_mask": alpha_mask
        }
        return (dataset_config,)
|
class TrainDatasetRegularization:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "dataset_path": ("STRING", {"multiline": True, "default": "", "tooltip": "path to dataset, root is the 'ComfyUI' folder, with windows portable 'ComfyUI_windows_portable'"}),
            "class_tokens": ("STRING", {"multiline": True, "default": "", "tooltip": "aka trigger word; if specified, it is added to the start of each caption, and if no captions exist it is used on its own"}),
            "num_repeats": ("INT", {"default": 1, "min": 1, "tooltip": "number of times to repeat dataset for an epoch"}),
            },
        }

    RETURN_TYPES = ("JSON",)
    RETURN_NAMES = ("subset",)
    FUNCTION = "create_config"
    CATEGORY = "FluxTrainer"

    def create_config(self, dataset_path, class_tokens, num_repeats):
        reg_subset = {
            "image_dir": dataset_path,
            "class_tokens": class_tokens,
            "num_repeats": num_repeats,
            "is_reg": True
        }

        return (reg_subset,)
|
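# Illustrative sketch of the entry one TrainDatasetAdd node appends to the
# "datasets" list; the values shown here are example inputs, not defaults
# enforced by the node:
#
#   {
#     "resolution": [1024, 1024],
#     "batch_size": 2,
#     "enable_bucket": true,
#     "bucket_no_upscale": false,
#     "min_bucket_reso": 256,
#     "max_bucket_reso": 1024,
#     "subsets": [{"image_dir": "path/to/images", "class_tokens": "myword", "num_repeats": 1}]
#   }
#
# The Init* nodes later run toml.dumps(json.loads(...)) on the accumulated
# structure to produce the dataset TOML for the kohya loader.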
class TrainDatasetAdd:
    def __init__(self):
        self.previous_dataset_signature = None

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "dataset_config": ("JSON",),
            "width": ("INT", {"min": 64, "default": 1024, "tooltip": "base resolution width"}),
            "height": ("INT", {"min": 64, "default": 1024, "tooltip": "base resolution height"}),
            "batch_size": ("INT", {"min": 1, "default": 2, "tooltip": "Higher batch size uses more memory and generalizes the training more"}),
            "dataset_path": ("STRING", {"multiline": True, "default": "", "tooltip": "path to dataset, root is the 'ComfyUI' folder, with windows portable 'ComfyUI_windows_portable'"}),
            "class_tokens": ("STRING", {"multiline": True, "default": "", "tooltip": "aka trigger word; if specified, it is added to the start of each caption, and if no captions exist it is used on its own"}),
            "enable_bucket": ("BOOLEAN", {"default": True, "tooltip": "enable buckets for multi aspect ratio training"}),
            "bucket_no_upscale": ("BOOLEAN", {"default": False, "tooltip": "don't allow upscaling when bucketing"}),
            "num_repeats": ("INT", {"default": 1, "min": 1, "tooltip": "number of times to repeat dataset for an epoch"}),
            "min_bucket_reso": ("INT", {"default": 256, "min": 64, "max": 4096, "step": 8, "tooltip": "min bucket resolution"}),
            "max_bucket_reso": ("INT", {"default": 1024, "min": 64, "max": 4096, "step": 8, "tooltip": "max bucket resolution"}),
            },
            "optional": {
                "regularization": ("JSON", {"tooltip": "reg data dir"}),
            }
        }

    RETURN_TYPES = ("JSON",)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "create_config"
    CATEGORY = "FluxTrainer"

    def create_config(self, dataset_config, dataset_path, class_tokens, width, height, batch_size, num_repeats, enable_bucket,
                      bucket_no_upscale, min_bucket_reso, max_bucket_reso, regularization=None):
        new_dataset = {
            "resolution": (width, height),
            "batch_size": batch_size,
            "enable_bucket": enable_bucket,
            "bucket_no_upscale": bucket_no_upscale,
            "min_bucket_reso": min_bucket_reso,
            "max_bucket_reso": max_bucket_reso,
            "subsets": [
                {
                    "image_dir": dataset_path,
                    "class_tokens": class_tokens,
                    "num_repeats": num_repeats
                }
            ]
        }
        if regularization is not None:
            new_dataset["subsets"].append(regularization)

        # Signature of the new dataset, so repeated queue runs replace the
        # dataset this node added previously instead of stacking duplicates
        new_dataset_signature = self.generate_signature(new_dataset)

        existing_datasets = json.loads(dataset_config["datasets"])

        # Remove the dataset this node appended on a previous run, if any
        if self.previous_dataset_signature:
            existing_datasets["datasets"] = [
                ds for ds in existing_datasets["datasets"]
                if self.generate_signature(ds) != self.previous_dataset_signature
            ]

        existing_datasets["datasets"].append(new_dataset)
        self.previous_dataset_signature = new_dataset_signature

        updated_dataset_json = json.dumps(existing_datasets, indent=2)
        dataset_config["datasets"] = updated_dataset_json

        return (dataset_config,)

    def generate_signature(self, dataset):
        return json.dumps(dataset, sort_keys=True)
|
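# Optimizer nodes collect their widget values into an ARGS dict that is later
# merged onto the parsed argparse namespace by the Init* nodes.
# extra_optimizer_args is split on '|'; for example (hypothetical values):
#   betas=0.9,0.99 | eps=1e-8
# becomes optimizer_args == ["betas=0.9,0.99", "eps=1e-8"].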
class OptimizerConfig:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "optimizer_type": (["adamw8bit", "adamw", "prodigy", "CAME", "Lion8bit", "Lion", "adamwschedulefree", "sgdschedulefree", "AdEMAMix8bit", "PagedAdEMAMix8bit", "ProdigyPlusScheduleFree"], {"default": "adamw8bit", "tooltip": "optimizer type"}),
            "max_grad_norm": ("FLOAT", {"default": 1.0, "min": 0.0, "tooltip": "gradient clipping"}),
            "lr_scheduler": (["constant", "cosine", "cosine_with_restarts", "polynomial", "constant_with_warmup"], {"default": "constant", "tooltip": "learning rate scheduler"}),
            "lr_warmup_steps": ("INT", {"default": 0, "min": 0, "tooltip": "learning rate warmup steps"}),
            "lr_scheduler_num_cycles": ("INT", {"default": 1, "min": 1, "tooltip": "learning rate scheduler num cycles"}),
            "lr_scheduler_power": ("FLOAT", {"default": 1.0, "min": 0.0, "tooltip": "learning rate scheduler power"}),
            "min_snr_gamma": ("FLOAT", {"default": 5.0, "min": 0.0, "step": 0.01, "tooltip": "gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by the paper"}),
            "extra_optimizer_args": ("STRING", {"multiline": True, "default": "", "tooltip": "additional optimizer args, separate multiple args with |"}),
            },
        }

    RETURN_TYPES = ("ARGS",)
    RETURN_NAMES = ("optimizer_settings",)
    FUNCTION = "create_config"
    CATEGORY = "FluxTrainer"

    def create_config(self, min_snr_gamma, extra_optimizer_args, **kwargs):
        kwargs["min_snr_gamma"] = min_snr_gamma if min_snr_gamma != 0.0 else None
        kwargs["optimizer_args"] = [arg.strip() for arg in extra_optimizer_args.strip().split('|') if arg.strip()]
        return (kwargs,)
|
class OptimizerConfigAdafactor:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "max_grad_norm": ("FLOAT", {"default": 0.0, "min": 0.0, "tooltip": "gradient clipping"}),
            "lr_scheduler": (["constant", "cosine", "cosine_with_restarts", "polynomial", "constant_with_warmup", "adafactor"], {"default": "constant_with_warmup", "tooltip": "learning rate scheduler"}),
            "lr_warmup_steps": ("INT", {"default": 0, "min": 0, "tooltip": "learning rate warmup steps"}),
            "lr_scheduler_num_cycles": ("INT", {"default": 1, "min": 1, "tooltip": "learning rate scheduler num cycles"}),
            "lr_scheduler_power": ("FLOAT", {"default": 1.0, "min": 0.0, "tooltip": "learning rate scheduler power"}),
            "relative_step": ("BOOLEAN", {"default": False, "tooltip": "relative step"}),
            "scale_parameter": ("BOOLEAN", {"default": False, "tooltip": "scale parameter"}),
            "warmup_init": ("BOOLEAN", {"default": False, "tooltip": "warmup init"}),
            "clip_threshold": ("FLOAT", {"default": 1.0, "min": 0.0, "tooltip": "clip threshold"}),
            "min_snr_gamma": ("FLOAT", {"default": 5.0, "min": 0.0, "step": 0.01, "tooltip": "gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by the paper"}),
            "extra_optimizer_args": ("STRING", {"multiline": True, "default": "", "tooltip": "additional optimizer args, separate multiple args with |"}),
            },
        }

    RETURN_TYPES = ("ARGS",)
    RETURN_NAMES = ("optimizer_settings",)
    FUNCTION = "create_config"
    CATEGORY = "FluxTrainer"

    def create_config(self, relative_step, scale_parameter, warmup_init, clip_threshold, min_snr_gamma, extra_optimizer_args, **kwargs):
        kwargs["optimizer_type"] = "adafactor"
        extra_args = [arg.strip() for arg in extra_optimizer_args.strip().split('|') if arg.strip()]
        node_args = [
            f"relative_step={relative_step}",
            f"scale_parameter={scale_parameter}",
            f"warmup_init={warmup_init}",
            f"clip_threshold={clip_threshold}"
        ]
        kwargs["optimizer_args"] = node_args + extra_args
        kwargs["min_snr_gamma"] = min_snr_gamma if min_snr_gamma != 0.0 else None

        return (kwargs,)
|
class FluxTrainerLossConfig:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "loss_type": (["l2", "huber", "smooth_l1"], {"default": "huber", "tooltip": "The type of loss function to use"}),
            "huber_schedule": (["snr", "exponential", "constant"], {"default": "exponential", "tooltip": "The scheduling method for Huber loss (constant, exponential, or SNR-based). Only used when loss_type is 'huber' or 'smooth_l1'. Upstream sd-scripts default is snr"}),
            "huber_c": ("FLOAT", {"default": 0.25, "min": 0.0, "step": 0.01, "tooltip": "The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth_l1) is selected with loss_type. Upstream sd-scripts default is 0.1"}),
            "huber_scale": ("FLOAT", {"default": 1.75, "min": 0.0, "step": 0.01, "tooltip": "The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth_l1) is selected with loss_type. Upstream sd-scripts default is 1.0"}),
            },
        }

    RETURN_TYPES = ("ARGS",)
    RETURN_NAMES = ("loss_args",)
    FUNCTION = "create_config"
    CATEGORY = "FluxTrainer"

    def create_config(self, **kwargs):
        return (kwargs,)
|
class OptimizerConfigProdigy:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "max_grad_norm": ("FLOAT", {"default": 0.0, "min": 0.0, "tooltip": "gradient clipping"}),
            "lr_scheduler": (["constant", "cosine", "cosine_with_restarts", "polynomial", "constant_with_warmup", "adafactor"], {"default": "constant", "tooltip": "learning rate scheduler"}),
            "lr_warmup_steps": ("INT", {"default": 0, "min": 0, "tooltip": "learning rate warmup steps"}),
            "lr_scheduler_num_cycles": ("INT", {"default": 1, "min": 1, "tooltip": "learning rate scheduler num cycles"}),
            "lr_scheduler_power": ("FLOAT", {"default": 1.0, "min": 0.0, "tooltip": "learning rate scheduler power"}),
            "weight_decay": ("FLOAT", {"default": 0.0, "step": 0.0001, "tooltip": "weight decay (L2 penalty)"}),
            "decouple": ("BOOLEAN", {"default": True, "tooltip": "use AdamW style weight decay"}),
            "use_bias_correction": ("BOOLEAN", {"default": False, "tooltip": "turn on Adam's bias correction"}),
            "min_snr_gamma": ("FLOAT", {"default": 5.0, "min": 0.0, "step": 0.01, "tooltip": "gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by the paper"}),
            "extra_optimizer_args": ("STRING", {"multiline": True, "default": "", "tooltip": "additional optimizer args, separate multiple args with |"}),
            },
        }

    RETURN_TYPES = ("ARGS",)
    RETURN_NAMES = ("optimizer_settings",)
    FUNCTION = "create_config"
    CATEGORY = "FluxTrainer"

    def create_config(self, weight_decay, decouple, min_snr_gamma, use_bias_correction, extra_optimizer_args, **kwargs):
        kwargs["optimizer_type"] = "prodigy"
        extra_args = [arg.strip() for arg in extra_optimizer_args.strip().split('|') if arg.strip()]
        node_args = [
            f"weight_decay={weight_decay}",
            f"decouple={decouple}",
            f"use_bias_correction={use_bias_correction}"
        ]
        kwargs["optimizer_args"] = node_args + extra_args
        kwargs["min_snr_gamma"] = min_snr_gamma if min_snr_gamma != 0.0 else None

        return (kwargs,)
|
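# Maps the selected network type to a kohya network module and its LyCORIS
# algo: LoKr -> "lokr", Locon -> "locon", LoHa -> "loha". Plain "lora" uses
# .networks.lora and takes none of the algo/factor/preset arguments.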
class TrainNetworkConfig:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "network_type": (["lora", "LyCORIS/LoKr", "LyCORIS/Locon", "LyCORIS/LoHa"], {"default": "lora", "tooltip": "network type"}),
            "lycoris_preset": (["full", "full-lin", "attn-mlp", "attn-only"], {"default": "attn-mlp"}),
            "factor": ("INT", {"default": -1, "min": -1, "max": 16, "step": 1, "tooltip": "LoKr factor"}),
            "extra_network_args": ("STRING", {"multiline": True, "default": "", "tooltip": "additional network args, separate multiple args with |"}),
            },
        }

    RETURN_TYPES = ("NETWORK_CONFIG",)
    RETURN_NAMES = ("network_config",)
    FUNCTION = "create_config"
    CATEGORY = "FluxTrainer"

    def create_config(self, network_type, extra_network_args, lycoris_preset, factor):
        extra_args = [arg.strip() for arg in extra_network_args.strip().split('|') if arg.strip()]

        algo = None
        if network_type == "lora":
            network_module = ".networks.lora"
        elif network_type == "LyCORIS/LoKr":
            network_module = ".lycoris.kohya"
            algo = "lokr"
        elif network_type == "LyCORIS/Locon":
            network_module = ".lycoris.kohya"
            algo = "locon"
        elif network_type == "LyCORIS/LoHa":
            network_module = ".lycoris.kohya"
            algo = "loha"

        # algo/factor/preset are LyCORIS-specific and are skipped for plain LoRA
        network_args = []
        if algo is not None:
            network_args = [
                f"algo={algo}",
                f"factor={factor}",
                f"preset={lycoris_preset}"
            ]
        network_config = {
            "network_module": network_module,
            "network_args": network_args + extra_args
        }

        return (network_config,)
|
class OptimizerConfigProdigyPlusScheduleFree:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "lr": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 1e-7, "tooltip": "Learning rate adjustment parameter. Increases or decreases the Prodigy learning rate."}),
            "max_grad_norm": ("FLOAT", {"default": 0.0, "min": 0.0, "tooltip": "gradient clipping"}),
            "prodigy_steps": ("INT", {"default": 0, "min": 0, "tooltip": "Freeze Prodigy stepsize adjustments after a certain optimiser step."}),
            "d0": ("FLOAT", {"default": 1e-6, "min": 0.0, "step": 1e-7, "tooltip": "initial learning rate"}),
            "d_coeff": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 1e-7, "tooltip": "Coefficient in the expression for the estimate of d (default 1.0). Values such as 0.5 and 2.0 typically work as well. Changing this parameter is the preferred way to tune the method."}),
            "split_groups": ("BOOLEAN", {"default": True, "tooltip": "Track individual adaptation values for each parameter group."}),
            "use_bias_correction": ("BOOLEAN", {"default": False, "tooltip": "Use the RAdam variant of schedule-free"}),
            "min_snr_gamma": ("FLOAT", {"default": 5.0, "min": 0.0, "step": 0.01, "tooltip": "gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by the paper"}),
            "use_stableadamw": ("BOOLEAN", {"default": True, "tooltip": "Scales parameter updates by the root-mean-square of the normalised gradient, in essence identical to Adafactor's gradient scaling. Set to False if the adaptive learning rate never improves."}),
            "use_cautious": ("BOOLEAN", {"default": False, "tooltip": "Experimental. Perform 'cautious' updates, as proposed in https://arxiv.org/pdf/2411.16085. Modifies the update to isolate and boost values that align with the current gradient."}),
            "use_adopt": ("BOOLEAN", {"default": False, "tooltip": "Experimental. Performs a modified step where the second moment is updated after the parameter update, so as not to include the current gradient in the denominator. This is a partial implementation of ADOPT (https://arxiv.org/abs/2411.02853), as we don't have a first moment to use for the update."}),
            "use_grams": ("BOOLEAN", {"default": False, "tooltip": "Perform 'grams' updates, as proposed in https://arxiv.org/abs/2412.17107. Modifies the update using sign operations that align with the current gradient. Note that we do not have access to a first moment, so this deviates from the paper (we apply the sign directly to the update). May have a limited effect."}),
            "stochastic_rounding": ("BOOLEAN", {"default": True, "tooltip": "Use stochastic rounding for bfloat16 weights"}),
            "use_orthograd": ("BOOLEAN", {"default": False, "tooltip": "Experimental. Updates weights using the component of the gradient that is orthogonal to the current weight direction, as described in https://arxiv.org/pdf/2501.04697. Can help prevent overfitting and improve generalisation."}),
            "use_focus": ("BOOLEAN", {"default": False, "tooltip": "Experimental. Modifies the update step to better handle noise at large step sizes (https://arxiv.org/abs/2501.12243). This method is incompatible with factorisation, Muon and Adam-atan2."}),
            "extra_optimizer_args": ("STRING", {"multiline": True, "default": "", "tooltip": "additional optimizer args, separate multiple args with |"}),
            },
        }

    RETURN_TYPES = ("ARGS",)
    RETURN_NAMES = ("optimizer_settings",)
    FUNCTION = "create_config"
    CATEGORY = "FluxTrainer"

    def create_config(self, min_snr_gamma, use_bias_correction, extra_optimizer_args, **kwargs):
        kwargs["optimizer_type"] = "ProdigyPlusScheduleFree"
        kwargs["lr_scheduler"] = "constant"
        extra_args = [arg.strip() for arg in extra_optimizer_args.strip().split('|') if arg.strip()]
        node_args = [
            f"use_bias_correction={use_bias_correction}",
        ]
        kwargs["optimizer_args"] = node_args + extra_args
        kwargs["min_snr_gamma"] = min_snr_gamma if min_snr_gamma != 0.0 else None

        return (kwargs,)
|
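# LoRA training initialisation: builds a kohya argparse namespace from the
# connected configs (dataset JSON converted to TOML, optimizer ARGS, widget
# kwargs), writes the resolved args and the current workflow to output_dir for
# reproducibility, then constructs FluxNetworkTrainer and returns its
# training-loop closure for the loop/save/validate nodes downstream.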
class InitFluxLoRATraining:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "flux_models": ("TRAIN_FLUX_MODELS",),
            "dataset": ("JSON",),
            "optimizer_settings": ("ARGS",),
            "output_name": ("STRING", {"default": "flux_lora", "multiline": False}),
            "output_dir": ("STRING", {"default": "flux_trainer_output", "multiline": False, "tooltip": "output directory, root is the 'ComfyUI' folder, with windows portable 'ComfyUI_windows_portable'"}),
            "network_dim": ("INT", {"default": 4, "min": 1, "max": 100000, "step": 1, "tooltip": "network dim"}),
            "network_alpha": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2048.0, "step": 0.01, "tooltip": "network alpha"}),
            "learning_rate": ("FLOAT", {"default": 4e-4, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "learning rate"}),
            "max_train_steps": ("INT", {"default": 1500, "min": 1, "max": 100000, "step": 1, "tooltip": "max number of training steps"}),
            "apply_t5_attn_mask": ("BOOLEAN", {"default": True, "tooltip": "apply t5 attention mask"}),
            "cache_latents": (["disk", "memory", "disabled"], {"tooltip": "caches VAE latents"}),
            "cache_text_encoder_outputs": (["disk", "memory", "disabled"], {"tooltip": "caches text encoder outputs"}),
            "blocks_to_swap": ("INT", {"default": 0, "tooltip": "previously known as split_mode; number of blocks to swap to save memory, 18 is a reasonable value to start with when enabling"}),
            "weighting_scheme": (["logit_normal", "sigma_sqrt", "mode", "cosmap", "none"],),
            "logit_mean": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "mean to use when using the logit_normal weighting scheme"}),
            "logit_std": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "std to use when using the logit_normal weighting scheme"}),
            "mode_scale": ("FLOAT", {"default": 1.29, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Scale of mode weighting scheme. Only effective when using the mode as the weighting_scheme"}),
            "timestep_sampling": (["sigmoid", "uniform", "sigma", "shift", "flux_shift"], {"tooltip": "Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, and shift of sigmoid (a value of 3.1582 is recommended for discrete_flow_shift)"}),
            "sigmoid_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1, "tooltip": "Scale factor for sigmoid timestep sampling (only used when timestep_sampling is sigmoid)"}),
            "model_prediction_type": (["raw", "additive", "sigma_scaled"], {"tooltip": "How to interpret and process the model prediction: raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."}),
            "guidance_scale": ("FLOAT", {"default": 1.0, "min": 1.0, "max": 32.0, "step": 0.01, "tooltip": "guidance scale, for Flux training should be 1.0"}),
            "discrete_flow_shift": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.0001, "tooltip": "for the Euler Discrete Scheduler, default is 3.0"}),
            "highvram": ("BOOLEAN", {"default": False, "tooltip": "memory mode"}),
            "fp8_base": ("BOOLEAN", {"default": True, "tooltip": "use fp8 for base model"}),
            "gradient_dtype": (["fp32", "fp16", "bf16"], {"default": "fp32", "tooltip": "the actual dtype training uses"}),
            "save_dtype": (["fp32", "fp16", "bf16", "fp8_e4m3fn", "fp8_e5m2"], {"default": "bf16", "tooltip": "the dtype to save checkpoints as"}),
            "attention_mode": (["sdpa", "xformers", "disabled"], {"default": "sdpa", "tooltip": "memory efficient attention mode"}),
            "sample_prompts": ("STRING", {"multiline": True, "default": "illustration of a kitten | photograph of a turtle", "tooltip": "validation sample prompts, for multiple prompts, separate by `|`"}),
            },
            "optional": {
                "additional_args": ("STRING", {"multiline": True, "default": "", "tooltip": "additional args to pass to the training command"}),
                "resume_args": ("ARGS", {"default": "", "tooltip": "resume args to pass to the training command"}),
                "train_text_encoder": (['disabled', 'clip_l', 'clip_l_fp8', 'clip_l+T5', 'clip_l+T5_fp8'], {"default": 'disabled', "tooltip": "also train the selected text encoders using the specified dtype; T5 can not be trained without clip_l"}),
                "clip_l_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}),
                "T5_lr": ("FLOAT", {"default": 0, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "text encoder learning rate"}),
                "block_args": ("ARGS", {"default": "", "tooltip": "limit the blocks used in the LoRA"}),
                "gradient_checkpointing": (["enabled", "enabled_with_cpu_offloading", "disabled"], {"default": "enabled", "tooltip": "use gradient checkpointing"}),
                "loss_args": ("ARGS", {"default": "", "tooltip": "loss args"}),
                "network_config": ("NETWORK_CONFIG", {"tooltip": "additional network config"}),
            },
            "hidden": {
                "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"
            },
        }

    RETURN_TYPES = ("NETWORKTRAINER", "INT", "KOHYA_ARGS",)
    RETURN_NAMES = ("network_trainer", "epochs_count", "args",)
    FUNCTION = "init_training"
    CATEGORY = "FluxTrainer"

    def init_training(self, flux_models, dataset, optimizer_settings, sample_prompts, output_name, attention_mode,
                      gradient_dtype, save_dtype, additional_args=None, resume_args=None, train_text_encoder='disabled',
                      block_args=None, gradient_checkpointing="enabled", prompt=None, extra_pnginfo=None, clip_l_lr=0, T5_lr=0, loss_args=None, network_config=None, **kwargs):
        mm.soft_empty_cache()

        output_dir = os.path.abspath(kwargs.get("output_dir"))
        os.makedirs(output_dir, exist_ok=True)

        total, used, free = shutil.disk_usage(output_dir)
        required_free_space = 2 * (2**30)
        if free <= required_free_space:
            raise ValueError(f"Insufficient disk space. Required: {required_free_space/2**30}GB. Available: {free/2**30}GB")

        dataset_config = dataset["datasets"]
        dataset_toml = toml.dumps(json.loads(dataset_config))

        parser = train_network_setup_parser()
        flux_train_utils.add_flux_train_arguments(parser)

        if additional_args is not None:
            print(f"additional_args: {additional_args}")
            args, _ = parser.parse_known_args(args=shlex.split(additional_args))
        else:
            args, _ = parser.parse_known_args()

        if kwargs.get("cache_latents") == "memory":
            kwargs["cache_latents"] = True
            kwargs["cache_latents_to_disk"] = False
        elif kwargs.get("cache_latents") == "disk":
            kwargs["cache_latents"] = True
            kwargs["cache_latents_to_disk"] = True
            # caption manipulation is incompatible with latents cached to disk
            kwargs["caption_dropout_rate"] = 0.0
            kwargs["shuffle_caption"] = False
            kwargs["token_warmup_step"] = 0.0
            kwargs["caption_tag_dropout_rate"] = 0.0
        else:
            kwargs["cache_latents"] = False
            kwargs["cache_latents_to_disk"] = False

        if kwargs.get("cache_text_encoder_outputs") == "memory":
            kwargs["cache_text_encoder_outputs"] = True
            kwargs["cache_text_encoder_outputs_to_disk"] = False
        elif kwargs.get("cache_text_encoder_outputs") == "disk":
            kwargs["cache_text_encoder_outputs"] = True
            kwargs["cache_text_encoder_outputs_to_disk"] = True
        else:
            kwargs["cache_text_encoder_outputs"] = False
            kwargs["cache_text_encoder_outputs_to_disk"] = False

        if '|' in sample_prompts:
            prompts = sample_prompts.split('|')
        else:
            prompts = [sample_prompts]

        config_dict = {
            "sample_prompts": prompts,
            "save_precision": save_dtype,
            "mixed_precision": "bf16",
            "num_cpu_threads_per_process": 1,
            "pretrained_model_name_or_path": flux_models["transformer"],
            "clip_l": flux_models["clip_l"],
            "t5xxl": flux_models["t5"],
            "ae": flux_models["vae"],
            "save_model_as": "safetensors",
            "persistent_data_loader_workers": False,
            "max_data_loader_n_workers": 0,
            "seed": 42,
            "network_module": ".networks.lora_flux" if network_config is None else network_config["network_module"],
            "dataset_config": dataset_toml,
            "output_name": f"{output_name}_rank{kwargs.get('network_dim')}_{save_dtype}",
            "loss_type": "l2",
            "t5xxl_max_token_length": 512,
            "alpha_mask": dataset["alpha_mask"],
            "network_train_unet_only": train_text_encoder == 'disabled',
            "fp8_base_unet": "fp8" in train_text_encoder,
            "disable_mmap_load_safetensors": False,
            "network_args": None if network_config is None else network_config["network_args"],
        }
        attention_settings = {
            "sdpa": {"mem_eff_attn": True, "xformers": False, "sdpa": True},
            "xformers": {"mem_eff_attn": True, "xformers": True, "sdpa": False}
        }
        config_dict.update(attention_settings.get(attention_mode, {}))

        gradient_dtype_settings = {
            "fp16": {"full_fp16": True, "full_bf16": False, "mixed_precision": "fp16"},
            "bf16": {"full_bf16": True, "full_fp16": False, "mixed_precision": "bf16"}
        }
        config_dict.update(gradient_dtype_settings.get(gradient_dtype, {}))

        if train_text_encoder != 'disabled':
            if clip_l_lr != "NaN":
                config_dict["text_encoder_lr"] = clip_l_lr
            if T5_lr != "NaN":
                config_dict["text_encoder_lr"] = [clip_l_lr, T5_lr]

        if gradient_checkpointing == "disabled":
            config_dict["gradient_checkpointing"] = False
        elif gradient_checkpointing == "enabled_with_cpu_offloading":
            config_dict["gradient_checkpointing"] = True
            config_dict["cpu_offload_checkpointing"] = True
        else:
            config_dict["gradient_checkpointing"] = True

        if flux_models["lora_path"]:
            config_dict["network_weights"] = flux_models["lora_path"]

        config_dict.update(kwargs)
        config_dict.update(optimizer_settings)

        if loss_args:
            config_dict.update(loss_args)

        if resume_args:
            config_dict.update(resume_args)

        for key, value in config_dict.items():
            setattr(args, key, value)

        additional_network_args = []

        if "T5" in train_text_encoder:
            additional_network_args.append("train_t5xxl=True")

        if block_args:
            additional_network_args.append(block_args["include"])

        if hasattr(args, 'network_args') and isinstance(args.network_args, list):
            args.network_args.extend(additional_network_args)
        else:
            setattr(args, 'network_args', additional_network_args)

        saved_args_file_path = os.path.join(output_dir, f"{output_name}_args.json")
        with open(saved_args_file_path, 'w') as f:
            json.dump(vars(args), f, indent=4)

        # Also save the current workflow alongside the training args
        metadata = {}
        if extra_pnginfo is not None:
            metadata.update(extra_pnginfo["workflow"])

        saved_workflow_file_path = os.path.join(output_dir, f"{output_name}_workflow.json")
        with open(saved_workflow_file_path, 'w') as f:
            json.dump(metadata, f, indent=4)

        with torch.inference_mode(False):
            network_trainer = FluxNetworkTrainer()
            training_loop = network_trainer.init_train(args)

        epochs_count = network_trainer.num_train_epochs

        trainer = {
            "network_trainer": network_trainer,
            "training_loop": training_loop,
        }
        return (trainer, epochs_count, args)
|
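# Full fine-tuning counterpart of InitFluxLoRATraining: trains the transformer
# itself (no network/LoRA arguments), so it checks for considerably more free
# disk space and enables memory-efficient saving by default.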
class InitFluxTraining:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "flux_models": ("TRAIN_FLUX_MODELS",),
            "dataset": ("JSON",),
            "optimizer_settings": ("ARGS",),
            "output_name": ("STRING", {"default": "flux", "multiline": False}),
            "output_dir": ("STRING", {"default": "flux_trainer_output", "multiline": False, "tooltip": "output directory, root is the 'ComfyUI' folder, with windows portable 'ComfyUI_windows_portable'"}),
            "learning_rate": ("FLOAT", {"default": 4e-6, "min": 0.0, "max": 10.0, "step": 0.000001, "tooltip": "learning rate"}),
            "max_train_steps": ("INT", {"default": 1500, "min": 1, "max": 100000, "step": 1, "tooltip": "max number of training steps"}),
            "apply_t5_attn_mask": ("BOOLEAN", {"default": True, "tooltip": "apply t5 attention mask"}),
            "t5xxl_max_token_length": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 8, "tooltip": "dev and LibreFlux use 512, schnell 256"}),
            "cache_latents": (["disk", "memory", "disabled"], {"tooltip": "caches VAE latents"}),
            "cache_text_encoder_outputs": (["disk", "memory", "disabled"], {"tooltip": "caches text encoder outputs"}),
            "weighting_scheme": (["logit_normal", "sigma_sqrt", "mode", "cosmap", "none"],),
            "logit_mean": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "mean to use when using the logit_normal weighting scheme"}),
            "logit_std": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "std to use when using the logit_normal weighting scheme"}),
            "mode_scale": ("FLOAT", {"default": 1.29, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Scale of mode weighting scheme. Only effective when using the mode as the weighting_scheme"}),
            "loss_type": (["l1", "l2", "huber", "smooth_l1"], {"default": "l2", "tooltip": "loss type"}),
            "timestep_sampling": (["sigmoid", "uniform", "sigma", "shift", "flux_shift"], {"tooltip": "Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, and shift of sigmoid (a value of 3.1582 is recommended for discrete_flow_shift)"}),
            "sigmoid_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1, "tooltip": "Scale factor for sigmoid timestep sampling (only used when timestep_sampling is sigmoid)"}),
            "model_prediction_type": (["raw", "additive", "sigma_scaled"], {"tooltip": "How to interpret and process the model prediction: raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)"}),
            "cpu_offload_checkpointing": ("BOOLEAN", {"default": True, "tooltip": "offload the gradient checkpointing to CPU, reduces VRAM usage by about 2GB"}),
            "optimizer_fusing": (['fused_backward_pass', 'blockwise_fused_optimizers'], {"tooltip": "reduces memory use"}),
            "blocks_to_swap": ("INT", {"default": 0, "min": 0, "max": 100, "step": 1, "tooltip": "Sets the number of blocks (~640MB each) to swap during the forward and backward passes; increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."}),
            "guidance_scale": ("FLOAT", {"default": 1.0, "min": 1.0, "max": 32.0, "step": 0.01, "tooltip": "guidance scale"}),
            "discrete_flow_shift": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.0001, "tooltip": "for the Euler Discrete Scheduler, default is 3.0"}),
            "highvram": ("BOOLEAN", {"default": False, "tooltip": "memory mode"}),
            "fp8_base": ("BOOLEAN", {"default": False, "tooltip": "use fp8 for base model"}),
            "gradient_dtype": (["fp32", "fp16", "bf16"], {"default": "bf16", "tooltip": "to use the full fp16/bf16 training"}),
            "save_dtype": (["fp32", "fp16", "bf16", "fp8_e4m3fn"], {"default": "bf16", "tooltip": "the dtype to save checkpoints as"}),
            "attention_mode": (["sdpa", "xformers", "disabled"], {"default": "sdpa", "tooltip": "memory efficient attention mode"}),
            "sample_prompts": ("STRING", {"multiline": True, "default": "illustration of a kitten | photograph of a turtle", "tooltip": "validation sample prompts, for multiple prompts, separate by `|`"}),
            },
            "optional": {
                "additional_args": ("STRING", {"multiline": True, "default": "", "tooltip": "additional args to pass to the training command"}),
                "resume_args": ("ARGS", {"default": "", "tooltip": "resume args to pass to the training command"}),
            },
        }

    RETURN_TYPES = ("NETWORKTRAINER", "INT", "KOHYA_ARGS")
    RETURN_NAMES = ("network_trainer", "epochs_count", "args")
    FUNCTION = "init_training"
    CATEGORY = "FluxTrainer"

    def init_training(self, flux_models, optimizer_settings, dataset, sample_prompts, output_name,
                      attention_mode, gradient_dtype, save_dtype, optimizer_fusing, additional_args=None, resume_args=None, **kwargs):
        mm.soft_empty_cache()

        output_dir = os.path.abspath(kwargs.get("output_dir"))
        os.makedirs(output_dir, exist_ok=True)

        total, used, free = shutil.disk_usage(output_dir)
        required_free_space = 25 * (2**30)
        if free <= required_free_space:
            raise ValueError(f"Most likely insufficient disk space to complete training. Required: {required_free_space/2**30}GB. Available: {free/2**30}GB")

        dataset_config = dataset["datasets"]
        dataset_toml = toml.dumps(json.loads(dataset_config))

        parser = train_setup_parser()
        flux_train_utils.add_flux_train_arguments(parser)

        if additional_args is not None:
            print(f"additional_args: {additional_args}")
            args, _ = parser.parse_known_args(args=shlex.split(additional_args))
        else:
            args, _ = parser.parse_known_args()

        if kwargs.get("cache_latents") == "memory":
            kwargs["cache_latents"] = True
            kwargs["cache_latents_to_disk"] = False
        elif kwargs.get("cache_latents") == "disk":
            kwargs["cache_latents"] = True
            kwargs["cache_latents_to_disk"] = True
            # caption manipulation is incompatible with latents cached to disk
            kwargs["caption_dropout_rate"] = 0.0
            kwargs["shuffle_caption"] = False
            kwargs["token_warmup_step"] = 0.0
            kwargs["caption_tag_dropout_rate"] = 0.0
        else:
            kwargs["cache_latents"] = False
            kwargs["cache_latents_to_disk"] = False

        if kwargs.get("cache_text_encoder_outputs") == "memory":
            kwargs["cache_text_encoder_outputs"] = True
            kwargs["cache_text_encoder_outputs_to_disk"] = False
        elif kwargs.get("cache_text_encoder_outputs") == "disk":
            kwargs["cache_text_encoder_outputs"] = True
            kwargs["cache_text_encoder_outputs_to_disk"] = True
        else:
            kwargs["cache_text_encoder_outputs"] = False
            kwargs["cache_text_encoder_outputs_to_disk"] = False

        if '|' in sample_prompts:
            prompts = sample_prompts.split('|')
        else:
            prompts = [sample_prompts]

        config_dict = {
            "sample_prompts": prompts,
            "save_precision": save_dtype,
            "mixed_precision": "bf16",
            "num_cpu_threads_per_process": 1,
            "pretrained_model_name_or_path": flux_models["transformer"],
            "clip_l": flux_models["clip_l"],
            "t5xxl": flux_models["t5"],
            "ae": flux_models["vae"],
            "save_model_as": "safetensors",
            "persistent_data_loader_workers": False,
            "max_data_loader_n_workers": 0,
            "seed": 42,
            "gradient_checkpointing": True,
            "dataset_config": dataset_toml,
            "output_name": f"{output_name}_{save_dtype}",
            "mem_eff_save": True,
            "disable_mmap_load_safetensors": True,
        }
        optimizer_fusing_settings = {
            "fused_backward_pass": {"fused_backward_pass": True, "blockwise_fused_optimizers": False},
            "blockwise_fused_optimizers": {"fused_backward_pass": False, "blockwise_fused_optimizers": True}
        }
        config_dict.update(optimizer_fusing_settings.get(optimizer_fusing, {}))

        attention_settings = {
            "sdpa": {"mem_eff_attn": True, "xformers": False, "sdpa": True},
            "xformers": {"mem_eff_attn": True, "xformers": True, "sdpa": False}
        }
        config_dict.update(attention_settings.get(attention_mode, {}))

        gradient_dtype_settings = {
            "fp16": {"full_fp16": True, "full_bf16": False, "mixed_precision": "fp16"},
            "bf16": {"full_bf16": True, "full_fp16": False, "mixed_precision": "bf16"}
        }
        config_dict.update(gradient_dtype_settings.get(gradient_dtype, {}))

        config_dict.update(kwargs)
        config_dict.update(optimizer_settings)

        if resume_args:
            config_dict.update(resume_args)

        for key, value in config_dict.items():
            setattr(args, key, value)

        with torch.inference_mode(False):
            network_trainer = FluxTrainer()
            training_loop = network_trainer.init_train(args)

        epochs_count = network_trainer.num_train_epochs

        saved_args_file_path = os.path.join(output_dir, f"{output_name}_args.json")
        with open(saved_args_file_path, 'w') as f:
            json.dump(vars(args), f, indent=4)

        trainer = {
            "network_trainer": network_trainer,
            "training_loop": training_loop,
        }
        return (trainer, epochs_count, args)
|
class InitFluxTrainingFromPreset:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "flux_models": ("TRAIN_FLUX_MODELS",),
            "dataset_settings": ("TOML_DATASET",),
            "preset_args": ("KOHYA_ARGS",),
            "output_name": ("STRING", {"default": "flux", "multiline": False}),
            "output_dir": ("STRING", {"default": "flux_trainer_output", "multiline": False, "tooltip": "output directory, root is ComfyUI folder"}),
            "sample_prompts": ("STRING", {"multiline": True, "default": "illustration of a kitten | photograph of a turtle", "tooltip": "validation sample prompts, for multiple prompts, separate by `|`"}),
            },
        }

    RETURN_TYPES = ("NETWORKTRAINER", "INT", "STRING", "KOHYA_ARGS")
    RETURN_NAMES = ("network_trainer", "epochs_count", "output_path", "args")
    FUNCTION = "init_training"
    CATEGORY = "FluxTrainer"

    def init_training(self, flux_models, dataset_settings, sample_prompts, output_name, preset_args, **kwargs):
        mm.soft_empty_cache()

        dataset = dataset_settings["dataset"]
        dataset_repeats = dataset_settings["repeats"]

        parser = train_setup_parser()
        args, _ = parser.parse_known_args()
        for key, value in vars(preset_args).items():
            setattr(args, key, value)

        output_dir = os.path.join(script_directory, "output")
        if '|' in sample_prompts:
            prompts = sample_prompts.split('|')
        else:
            prompts = [sample_prompts]

        width, height = toml.loads(dataset)["datasets"][0]["resolution"]
        config_dict = {
            "sample_prompts": prompts,
            "dataset_repeats": dataset_repeats,
            "num_cpu_threads_per_process": 1,
            "pretrained_model_name_or_path": flux_models["transformer"],
            "clip_l": flux_models["clip_l"],
            "t5xxl": flux_models["t5"],
            "ae": flux_models["vae"],
            "save_model_as": "safetensors",
            "persistent_data_loader_workers": False,
            "max_data_loader_n_workers": 0,
            "seed": 42,
            "gradient_checkpointing": True,
            "dataset_config": dataset,
            "output_dir": output_dir,
            "output_name": f"{output_name}_rank{kwargs.get('network_dim')}_{args.save_precision}",
            "width": int(width),
            "height": int(height),
        }

        config_dict.update(kwargs)

        for key, value in config_dict.items():
            setattr(args, key, value)

        with torch.inference_mode(False):
            network_trainer = FluxNetworkTrainer()
            training_loop = network_trainer.init_train(args)

        final_output_path = os.path.join(output_dir, output_name)

        epochs_count = network_trainer.num_train_epochs

        trainer = {
            "network_trainer": network_trainer,
            "training_loop": training_loop,
        }
        return (trainer, epochs_count, final_output_path, args)
|
class FluxTrainLoop:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "network_trainer": ("NETWORKTRAINER",),
            "steps": ("INT", {"default": 1, "min": 1, "max": 10000, "step": 1, "tooltip": "the step point in training to validate/save"}),
            },
        }

    RETURN_TYPES = ("NETWORKTRAINER", "INT",)
    RETURN_NAMES = ("network_trainer", "steps",)
    FUNCTION = "train"
    CATEGORY = "FluxTrainer"

    def train(self, network_trainer, steps):
        with torch.inference_mode(False):
            training_loop = network_trainer["training_loop"]
            network_trainer = network_trainer["network_trainer"]
            initial_global_step = network_trainer.global_step

            target_global_step = network_trainer.global_step + steps
            comfy_pbar = comfy.utils.ProgressBar(steps)
            network_trainer.comfy_pbar = comfy_pbar

            network_trainer.optimizer_train_fn()

            while network_trainer.global_step < target_global_step:
                steps_done = training_loop(
                    break_at_steps=target_global_step,
                    epoch=network_trainer.current_epoch.value,
                )

                if network_trainer.global_step >= network_trainer.args.max_train_steps:
                    break

            trainer = {
                "network_trainer": network_trainer,
                "training_loop": training_loop,
            }
        return (trainer, network_trainer.global_step)
|
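# Runs training to completion, pausing at the next multiple of
# validate_at_steps / save_at_steps. The next stop is computed as
# ((global_step // interval) + 1) * interval, so each pass breaks at whichever
# of the two checkpoints comes first.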
class FluxTrainAndValidateLoop:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "network_trainer": ("NETWORKTRAINER",),
            "validate_at_steps": ("INT", {"default": 250, "min": 1, "max": 10000, "step": 1, "tooltip": "the step point in training to validate/save"}),
            "save_at_steps": ("INT", {"default": 250, "min": 1, "max": 10000, "step": 1, "tooltip": "the step point in training to validate/save"}),
            },
            "optional": {
                "validation_settings": ("VALSETTINGS",),
            }
        }

    RETURN_TYPES = ("NETWORKTRAINER", "INT",)
    RETURN_NAMES = ("network_trainer", "steps",)
    FUNCTION = "train"
    CATEGORY = "FluxTrainer"

    def train(self, network_trainer, validate_at_steps, save_at_steps, validation_settings=None):
        with torch.inference_mode(False):
            training_loop = network_trainer["training_loop"]
            network_trainer = network_trainer["network_trainer"]

            target_global_step = network_trainer.args.max_train_steps
            comfy_pbar = comfy.utils.ProgressBar(target_global_step)
            network_trainer.comfy_pbar = comfy_pbar

            network_trainer.optimizer_train_fn()

            while network_trainer.global_step < target_global_step:
                next_validate_step = ((network_trainer.global_step // validate_at_steps) + 1) * validate_at_steps
                next_save_step = ((network_trainer.global_step // save_at_steps) + 1) * save_at_steps

                steps_done = training_loop(
                    break_at_steps=min(next_validate_step, next_save_step),
                    epoch=network_trainer.current_epoch.value,
                )

                # Check if we need to validate
                if network_trainer.global_step % validate_at_steps == 0:
                    self.validate(network_trainer, validation_settings)

                # Check if we need to save
                if network_trainer.global_step % save_at_steps == 0:
                    self.save(network_trainer)

                if network_trainer.global_step >= network_trainer.args.max_train_steps:
                    break

            trainer = {
                "network_trainer": network_trainer,
                "training_loop": training_loop,
            }
        return (trainer, network_trainer.global_step)

    def validate(self, network_trainer, validation_settings=None):
        params = (
            network_trainer.current_epoch.value,
            network_trainer.global_step,
            validation_settings
        )
        network_trainer.optimizer_eval_fn()
        image_tensors = network_trainer.sample_images(*params)
        network_trainer.optimizer_train_fn()
        print("Validating at step:", network_trainer.global_step)

    def save(self, network_trainer):
        ckpt_name = train_util.get_step_ckpt_name(network_trainer.args, "." + network_trainer.args.save_model_as, network_trainer.global_step)
        network_trainer.optimizer_eval_fn()
        network_trainer.save_model(ckpt_name, network_trainer.accelerator.unwrap_model(network_trainer.network), network_trainer.global_step, network_trainer.current_epoch.value + 1)
        network_trainer.optimizer_train_fn()
        print("Saving at step:", network_trainer.global_step)
|
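# Saves a stepwise LoRA checkpoint through kohya's train_util naming helpers;
# get_remove_step_no decides whether an older stepwise checkpoint should be
# rotated out according to the configured save limits.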
class FluxTrainSave:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "network_trainer": ("NETWORKTRAINER",),
            "save_state": ("BOOLEAN", {"default": False, "tooltip": "save the whole model state as well"}),
            "copy_to_comfy_lora_folder": ("BOOLEAN", {"default": False, "tooltip": "copy the lora model to the comfy lora folder"}),
            },
        }

    RETURN_TYPES = ("NETWORKTRAINER", "STRING", "INT",)
    RETURN_NAMES = ("network_trainer", "lora_path", "steps",)
    FUNCTION = "save"
    CATEGORY = "FluxTrainer"

    def save(self, network_trainer, save_state, copy_to_comfy_lora_folder):
        with torch.inference_mode(False):
            trainer = network_trainer["network_trainer"]
            global_step = trainer.global_step

            ckpt_name = train_util.get_step_ckpt_name(trainer.args, "." + trainer.args.save_model_as, global_step)
            trainer.save_model(ckpt_name, trainer.accelerator.unwrap_model(trainer.network), global_step, trainer.current_epoch.value + 1)

            remove_step_no = train_util.get_remove_step_no(trainer.args, global_step)
            if remove_step_no is not None:
                remove_ckpt_name = train_util.get_step_ckpt_name(trainer.args, "." + trainer.args.save_model_as, remove_step_no)
                trainer.remove_model(remove_ckpt_name)

            if save_state:
                train_util.save_and_remove_state_stepwise(trainer.args, trainer.accelerator, global_step)

            lora_path = os.path.join(trainer.args.output_dir, ckpt_name)
            if copy_to_comfy_lora_folder:
                destination_dir = os.path.join(folder_paths.models_dir, "loras", "flux_trainer")
                os.makedirs(destination_dir, exist_ok=True)
                shutil.copy(lora_path, os.path.join(destination_dir, ckpt_name))

        return (network_trainer, lora_path, global_step)
|
class FluxTrainSaveModel:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "network_trainer": ("NETWORKTRAINER",),
            "copy_to_comfy_model_folder": ("BOOLEAN", {"default": False, "tooltip": "copy the model to the comfy diffusion_models folder"}),
            "end_training": ("BOOLEAN", {"default": False, "tooltip": "end the training"}),
            },
        }

    RETURN_TYPES = ("NETWORKTRAINER", "STRING", "INT",)
    RETURN_NAMES = ("network_trainer", "model_path", "steps",)
    FUNCTION = "save"
    CATEGORY = "FluxTrainer"

    def save(self, network_trainer, copy_to_comfy_model_folder, end_training):
        with torch.inference_mode(False):
            trainer = network_trainer["network_trainer"]
            global_step = trainer.global_step

            trainer.optimizer_eval_fn()

            ckpt_name = train_util.get_step_ckpt_name(trainer.args, "." + trainer.args.save_model_as, global_step)
            flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
                trainer.args,
                False,
                trainer.accelerator,
                trainer.save_dtype,
                trainer.current_epoch.value,
                trainer.num_train_epochs,
                global_step,
                trainer.accelerator.unwrap_model(trainer.unet)
            )

            model_path = os.path.join(trainer.args.output_dir, ckpt_name)
            if copy_to_comfy_model_folder:
                destination_dir = os.path.join(folder_paths.models_dir, "diffusion_models", "flux_trainer")
                os.makedirs(destination_dir, exist_ok=True)
                shutil.copy(model_path, os.path.join(destination_dir, ckpt_name))
                model_path = os.path.join(destination_dir, ckpt_name)
            if end_training:
                trainer.accelerator.end_training()

        return (network_trainer, model_path, global_step)
|
class FluxTrainEnd:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "network_trainer": ("NETWORKTRAINER",),
            "save_state": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("STRING", "STRING", "STRING",)
    RETURN_NAMES = ("lora_name", "metadata", "lora_path",)
    FUNCTION = "endtrain"
    CATEGORY = "FluxTrainer"
    OUTPUT_NODE = True

    def endtrain(self, network_trainer, save_state):
        with torch.inference_mode(False):
            training_loop = network_trainer["training_loop"]
            network_trainer = network_trainer["network_trainer"]

            network_trainer.metadata["ss_epoch"] = str(network_trainer.num_train_epochs)
            network_trainer.metadata["ss_training_finished_at"] = str(time.time())

            network = network_trainer.accelerator.unwrap_model(network_trainer.network)

            network_trainer.accelerator.end_training()
            network_trainer.optimizer_eval_fn()

            if save_state:
                train_util.save_state_on_train_end(network_trainer.args, network_trainer.accelerator)

            ckpt_name = train_util.get_last_ckpt_name(network_trainer.args, "." + network_trainer.args.save_model_as)
            network_trainer.save_model(ckpt_name, network, network_trainer.global_step, network_trainer.num_train_epochs, force_sync_upload=True)
            logger.info("model saved.")

            final_lora_name = str(network_trainer.args.output_name)
            final_lora_path = os.path.join(network_trainer.args.output_dir, ckpt_name)

            metadata = json.dumps(network_trainer.metadata, indent=2)

            training_loop = None
            network_trainer = None
            mm.soft_empty_cache()

        return (final_lora_name, metadata, final_lora_path)
|
class FluxTrainResume:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "load_state_path": ("STRING", {"default": "", "multiline": True, "tooltip": "path to load state from"}),
            "skip_until_initial_step": ("BOOLEAN", {"default": False}),
            },
        }

    RETURN_TYPES = ("ARGS", )
    RETURN_NAMES = ("resume_args", )
    FUNCTION = "resume"
    CATEGORY = "FluxTrainer"

    def resume(self, load_state_path, skip_until_initial_step):
        resume_args = {
            "resume": load_state_path,
            "skip_until_initial_step": skip_until_initial_step
        }

        return (resume_args, )
|
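# Expands a comma-separated block list into kohya's only_if_contains filter.
# Ranges in parentheses are unrolled, e.g. the (illustrative) input
#   lora_unet_single_blocks_(10-12)_linear2
# expands to lora_unet_single_blocks_10_linear2, ..._11_linear2, ..._12_linear2.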
class FluxTrainBlockSelect:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "include": ("STRING", {"default": "lora_unet_single_blocks_20_linear2", "multiline": True, "tooltip": "blocks to include in the LoRA network; to select multiple blocks, separate them with commas, ranges in parentheses such as lora_unet_single_blocks_(0-15)_linear2 are expanded"}),
            },
        }

    RETURN_TYPES = ("ARGS", )
    RETURN_NAMES = ("block_args", )
    FUNCTION = "block_select"
    CATEGORY = "FluxTrainer"

    def block_select(self, include):
        import re

        # Split the input string into individual elements
        elements = include.split(',')

        # List to store the processed block names
        blocks = []

        # Regular expression to match numeric ranges like (10-20)
        pattern = re.compile(r'\((\d+)-(\d+)\)')

        # Pattern to split the prefix from the rest of the block name
        prefix_suffix_pattern = re.compile(r'(.*)_blocks_(.*)')

        for element in elements:
            element = element.strip()
            match = prefix_suffix_pattern.match(element)
            if match:
                prefix = match.group(1) + "_blocks_"
                suffix = match.group(2)
                matches = pattern.findall(suffix)
                if matches:
                    for start, end in matches:
                        # Expand the range into individual block names
                        blocks.extend([f"{prefix}{i}{suffix.replace(f'({start}-{end})', '', 1)}" for i in range(int(start), int(end) + 1)])
                else:
                    # No range found, use the element as-is
                    blocks.append(element)
            else:
                blocks.append(element)

        include_string = ','.join(blocks)

        block_args = {
            "include": f"only_if_contains={include_string}",
        }

        return (block_args, )
|
class FluxTrainValidationSettings:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "steps": ("INT", {"default": 20, "min": 1, "max": 256, "step": 1, "tooltip": "sampling steps"}),
            "width": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 8, "tooltip": "image width"}),
            "height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 8, "tooltip": "image height"}),
            "guidance_scale": ("FLOAT", {"default": 3.5, "min": 1.0, "max": 32.0, "step": 0.05, "tooltip": "guidance scale"}),
            "seed": ("INT", {"default": 42, "min": 0, "max": 0xffffffffffffffff, "step": 1}),
            "shift": ("BOOLEAN", {"default": True, "tooltip": "shift the schedule to favor high timesteps for higher signal images"}),
            "base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
            "max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 10.0, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("VALSETTINGS", )
    RETURN_NAMES = ("validation_settings", )
    FUNCTION = "set"
    CATEGORY = "FluxTrainer"

    def set(self, **kwargs):
        validation_settings = kwargs
        print(validation_settings)

        return (validation_settings,)
|
class FluxTrainValidate:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "network_trainer": ("NETWORKTRAINER",),
            },
            "optional": {
                "validation_settings": ("VALSETTINGS",),
            }
        }

    RETURN_TYPES = ("NETWORKTRAINER", "IMAGE",)
    RETURN_NAMES = ("network_trainer", "validation_images",)
    FUNCTION = "validate"
    CATEGORY = "FluxTrainer"

    def validate(self, network_trainer, validation_settings=None):
        training_loop = network_trainer["training_loop"]
        network_trainer = network_trainer["network_trainer"]

        params = (
            network_trainer.current_epoch.value,
            network_trainer.global_step,
            validation_settings
        )
        network_trainer.optimizer_eval_fn()
        with torch.inference_mode(False):
            image_tensors = network_trainer.sample_images(*params)

        trainer = {
            "network_trainer": network_trainer,
            "training_loop": training_loop,
        }
        # Rescale sampled images from [-1, 1] to [0, 1] for ComfyUI
        return (trainer, (0.5 * (image_tensors + 1.0)).cpu().float(),)
|
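# Plots the recorded per-step loss with an optional simple moving average
# (np.convolve with a uniform window) and returns the plot as an IMAGE tensor
# in ComfyUI's [B, H, W, C] layout.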
class VisualizeLoss: |
|
@classmethod |
|
def INPUT_TYPES(s): |
|
return {"required": { |
|
"network_trainer": ("NETWORKTRAINER",), |
|
"plot_style": (plt.style.available,{"default": 'default', "tooltip": "matplotlib plot style"}), |
|
"window_size": ("INT", {"default": 100, "min": 0, "max": 10000, "step": 1, "tooltip": "the window size of the moving average"}), |
|
"normalize_y": ("BOOLEAN", {"default": True, "tooltip": "normalize the y-axis to 0"}), |
|
"width": ("INT", {"default": 768, "min": 256, "max": 4096, "step": 2, "tooltip": "width of the plot in pixels"}), |
|
"height": ("INT", {"default": 512, "min": 256, "max": 4096, "step": 2, "tooltip": "height of the plot in pixels"}), |
|
"log_scale": ("BOOLEAN", {"default": False, "tooltip": "use log scale on the y-axis"}), |
|
}, |
|
} |
|
|
|
RETURN_TYPES = ("IMAGE", "FLOAT",) |
|
RETURN_NAMES = ("plot", "loss_list",) |
|
FUNCTION = "draw" |
|
CATEGORY = "FluxTrainer" |
|
|
|
def draw(self, network_trainer, window_size, plot_style, normalize_y, width, height, log_scale): |
|
import numpy as np |
|
loss_values = network_trainer["network_trainer"].loss_recorder.global_loss_list |
|
|
|
|
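# smooth the raw per-step loss with a simple moving average when window_size > 0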
|
def moving_average(values, window_size): |
|
return np.convolve(values, np.ones(window_size) / window_size, mode='valid') |
|
if window_size > 0: |
|
loss_values = moving_average(loss_values, window_size) |
|
|
|
plt.style.use(plot_style) |
|
|
|
|
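# matplotlib figsize is in inches and the figure defaults to 100 DPI, so divide the pixel size by 100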
|
width_inches = width / 100 |
|
height_inches = height / 100 |
|
|
|
|
|
fig, ax = plt.subplots(figsize=(width_inches, height_inches)) |
|
ax.plot(loss_values, label='Training Loss') |
|
ax.set_xlabel('Step') |
|
ax.set_ylabel('Loss') |
|
if normalize_y:
ax.set_ylim(bottom=0)
|
if log_scale: |
|
ax.set_yscale('log') |
|
ax.set_title('Training Loss Over Time') |
|
ax.legend() |
|
ax.grid(True) |
|
|
|
buf = io.BytesIO() |
|
plt.savefig(buf, format='png') |
|
plt.close(fig) |
|
buf.seek(0) |
|
|
|
image = Image.open(buf).convert('RGB') |
|
|
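# convert the PIL plot to a ComfyUI IMAGE tensor: [B, H, W, C], float in 0-1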
|
image_tensor = transforms.ToTensor()(image) |
|
image_tensor = image_tensor.unsqueeze(0).permute(0, 2, 3, 1).cpu().float() |
|
|
|
return (image_tensor, list(loss_values),)
|
|
|
class FluxKohyaInferenceSampler: |
|
@classmethod |
|
def INPUT_TYPES(s): |
|
return {"required": { |
|
"flux_models": ("TRAIN_FLUX_MODELS",), |
|
"lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), |
|
"lora_method": (["apply", "merge"], {"tooltip": "whether to apply or merge the lora weights"}), |
|
"steps": ("INT", {"default": 20, "min": 1, "max": 256, "step": 1, "tooltip": "sampling steps"}), |
|
"width": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 8, "tooltip": "image width"}), |
|
"height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 8, "tooltip": "image height"}), |
|
"guidance_scale": ("FLOAT", {"default": 3.5, "min": 1.0, "max": 32.0, "step": 0.05, "tooltip": "guidance scale"}), |
|
"seed": ("INT", {"default": 42,"min": 0, "max": 0xffffffffffffffff, "step": 1}), |
|
"use_fp8": ("BOOLEAN", {"default": True, "tooltip": "use fp8 weights"}), |
|
"apply_t5_attn_mask": ("BOOLEAN", {"default": True, "tooltip": "use t5 attention mask"}), |
|
"prompt": ("STRING", {"multiline": True, "default": "illustration of a kitten", "tooltip": "prompt"}), |
|
|
|
}, |
|
} |
|
|
|
RETURN_TYPES = ("IMAGE", ) |
|
RETURN_NAMES = ("image", ) |
|
FUNCTION = "sample" |
|
CATEGORY = "FluxTrainer" |
|
|
|
def sample(self, flux_models, lora_name, steps, width, height, guidance_scale, seed, prompt, use_fp8, lora_method, apply_t5_attn_mask): |
|
|
|
from .library import flux_utils as flux_utils |
|
from .library import strategy_flux as strategy_flux |
|
from .networks import lora_flux as lora_flux |
|
from typing import List, Optional, Callable |
|
from tqdm import tqdm |
|
import einops |
|
import math |
|
import accelerate |
|
import gc |
|
|
|
device = "cuda" |
|
|
|
|
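# with fp8 weights, computation still runs under accelerate's bf16 autocast; otherwise plain fp16 autocast is used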
|
if use_fp8: |
|
accelerator = accelerate.Accelerator(mixed_precision="bf16") |
|
dtype = torch.float8_e4m3fn |
|
else: |
|
dtype = torch.float16 |
|
accelerator = None |
|
loading_device = "cpu" |
|
ae_dtype = torch.bfloat16 |
|
|
|
pretrained_model_name_or_path = flux_models["transformer"] |
|
clip_l = flux_models["clip_l"] |
|
t5xxl = flux_models["t5"] |
|
ae = flux_models["vae"] |
|
lora_path = folder_paths.get_full_path("loras", lora_name) |
|
|
|
|
|
logger.info(f"Loading clip_l from {clip_l}...") |
|
clip_l = flux_utils.load_clip_l(clip_l, None, loading_device) |
|
clip_l.eval() |
|
|
|
logger.info(f"Loading t5xxl from {t5xxl}...") |
|
t5xxl = flux_utils.load_t5xxl(t5xxl, None, loading_device) |
|
t5xxl.eval() |
|
|
|
if use_fp8: |
|
clip_l = accelerator.prepare(clip_l) |
|
t5xxl = accelerator.prepare(t5xxl) |
|
|
|
t5xxl_max_length = 512 |
|
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length) |
|
encoding_strategy = strategy_flux.FluxTextEncodingStrategy() |
|
|
|
|
|
model = flux_utils.load_flow_model("dev", pretrained_model_name_or_path, dtype, loading_device) |
|
model.eval() |
|
logger.info(f"Casting model to {dtype}") |
|
model.to(dtype) |
|
if use_fp8: |
|
model = accelerator.prepare(model) |
|
|
|
|
|
ae = flux_utils.load_ae("dev", ae, ae_dtype, loading_device) |
|
ae.eval() |
|
|
|
|
|
|
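# build the LoRA network from the saved weights and either merge it into, or apply it onto, the text encoders and transformer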
|
lora_models: List[lora_flux.LoRANetwork] = [] |
|
multiplier = 1.0 |
|
|
|
lora_model, weights_sd = lora_flux.create_network_from_weights( |
|
multiplier, lora_path, ae, [clip_l, t5xxl], model, None, True |
|
) |
|
if lora_method == "merge": |
|
lora_model.merge_to([clip_l, t5xxl], model, weights_sd) |
|
elif lora_method == "apply": |
|
lora_model.apply_to([clip_l, t5xxl], model) |
|
info = lora_model.load_state_dict(weights_sd, strict=True) |
|
logger.info(f"Loaded LoRA weights from {lora_name}: {info}") |
|
lora_model.eval() |
|
lora_model.to(device) |
|
lora_models.append(lora_model) |
|
|
|
|
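# FLUX packs 2x2 latent patches: the sequence has (H/16)*(W/16) tokens of 16*2*2 channels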
|
packed_latent_height, packed_latent_width = math.ceil(height / 16), math.ceil(width / 16) |
|
noise = torch.randn( |
|
1, |
|
packed_latent_height * packed_latent_width, |
|
16 * 2 * 2, |
|
device=device, |
|
dtype=ae_dtype, |
|
generator=torch.Generator(device=device).manual_seed(seed), |
|
) |
|
|
|
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width) |
|
|
|
|
|
logger.info("Encoding prompts...") |
|
tokens_and_masks = tokenize_strategy.tokenize(prompt) |
|
clip_l = clip_l.to(device) |
|
t5xxl = t5xxl.to(device) |
|
with torch.no_grad(): |
|
if use_fp8: |
|
clip_l.to(ae_dtype) |
|
t5xxl.to(ae_dtype) |
|
with accelerator.autocast(): |
|
l_pooled, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( |
|
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, apply_t5_attn_mask |
|
) |
|
else: |
|
with torch.autocast(device_type=device.type, dtype=dtype): |
|
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) |
|
with torch.autocast(device_type=device.type, dtype=dtype): |
|
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( |
|
tokenize_strategy, [None, t5xxl], tokens_and_masks, apply_t5_attn_mask |
|
) |
|
|
|
if torch.isnan(l_pooled).any(): |
|
raise ValueError("NaN in l_pooled") |
|
|
|
if torch.isnan(t5_out).any(): |
|
raise ValueError("NaN in t5_out") |
|
|
|
|
|
clip_l = clip_l.cpu() |
|
t5xxl = t5xxl.cpu() |
|
|
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
|
|
|
|
logger.info("Generating image...") |
|
model = model.to(device) |
|
print("MODEL DTYPE: ", model.dtype) |
|
|
|
img_ids = img_ids.to(device) |
|
t5_attn_mask = t5_attn_mask.to(device) if apply_t5_attn_mask else None |
|
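# flow-matching schedule helpers: mu is interpolated between base_shift and max_shift from the image sequence length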
def time_shift(mu: float, sigma: float, t: torch.Tensor): |
|
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) |
|
|
|
|
|
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]: |
|
m = (y2 - y1) / (x2 - x1) |
|
b = y1 - m * x1 |
|
return lambda x: m * x + b |
|
|
|
|
|
def get_schedule( |
|
num_steps: int, |
|
image_seq_len: int, |
|
base_shift: float = 0.5, |
|
max_shift: float = 1.15, |
|
shift: bool = True, |
|
) -> list[float]: |
|
|
|
timesteps = torch.linspace(1, 0, num_steps + 1) |
|
|
|
|
|
if shift: |
|
|
|
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) |
|
timesteps = time_shift(mu, 1.0, timesteps) |
|
|
|
return timesteps.tolist() |
|
|
|
|
|
def denoise( |
|
model, |
|
img: torch.Tensor, |
|
img_ids: torch.Tensor, |
|
txt: torch.Tensor, |
|
txt_ids: torch.Tensor, |
|
vec: torch.Tensor, |
|
timesteps: list[float], |
|
guidance: float = 4.0, |
|
t5_attn_mask: Optional[torch.Tensor] = None, |
|
): |
|
|
|
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) |
|
comfy_pbar = comfy.utils.ProgressBar(total=len(timesteps)) |
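# Euler integration of the rectified-flow ODE from t_curr down to t_prev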
|
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): |
|
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) |
|
pred = model( |
|
img=img, |
|
img_ids=img_ids, |
|
txt=txt, |
|
txt_ids=txt_ids, |
|
y=vec, |
|
timesteps=t_vec, |
|
guidance=guidance_vec, |
|
txt_attention_mask=t5_attn_mask, |
|
) |
|
img = img + (t_prev - t_curr) * pred |
|
comfy_pbar.update(1) |
|
|
|
return img |
|
def do_sample( |
|
accelerator: Optional[accelerate.Accelerator], |
|
model, |
|
img: torch.Tensor, |
|
img_ids: torch.Tensor, |
|
l_pooled: torch.Tensor, |
|
t5_out: torch.Tensor, |
|
txt_ids: torch.Tensor, |
|
num_steps: int, |
|
guidance: float, |
|
t5_attn_mask: Optional[torch.Tensor], |
|
is_schnell: bool, |
|
device: torch.device, |
|
flux_dtype: torch.dtype, |
|
): |
|
timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell) |
|
logger.info(f"Timesteps: {timesteps}")
|
|
|
|
|
if accelerator: |
|
with accelerator.autocast(), torch.no_grad(): |
|
x = denoise( |
|
model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask |
|
) |
|
else:
with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
x = denoise(
model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask
)
|
|
|
return x |
|
|
|
x = do_sample(accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance_scale, t5_attn_mask, False, device, dtype) |
|
|
|
model = model.cpu() |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
|
|
|
|
x = x.float() |
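# unpack the 2x2 patch token sequence back into a [B, C, H, W] latent grid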
|
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2) |
|
|
|
|
|
logger.info("Decoding image...") |
|
ae = ae.to(device) |
|
with torch.no_grad(): |
|
if use_fp8: |
|
with accelerator.autocast(): |
|
x = ae.decode(x) |
|
else: |
|
with torch.autocast(device_type=device.type, dtype=ae_dtype): |
|
x = ae.decode(x) |
|
|
|
ae = ae.cpu() |
|
|
|
x = x.clamp(-1, 1) |
|
x = x.permute(0, 2, 3, 1) |
|
|
|
return ((0.5 * (x + 1.0)).cpu().float(),) |
|
|
|
class UploadToHuggingFace: |
|
@classmethod |
|
def INPUT_TYPES(s): |
|
return { |
|
"required": { |
|
"network_trainer": ("NETWORKTRAINER",), |
|
"source_path": ("STRING", {"default": ""}), |
|
"repo_id": ("STRING",{"default": ""}), |
|
"revision": ("STRING", {"default": ""}), |
|
"private": ("BOOLEAN", {"default": True, "tooltip": "If creating a new repo, leave it private"}), |
|
}, |
|
"optional": { |
|
"token": ("STRING", {"default": "","tooltip":"DO NOT LEAVE IN THE NODE or it might save in metadata, can also use the hf_token.json"}), |
|
} |
|
} |
|
|
|
RETURN_TYPES = ("NETWORKTRAINER", "STRING",) |
|
RETURN_NAMES = ("network_trainer","status",) |
|
FUNCTION = "upload" |
|
CATEGORY = "FluxTrainer" |
|
|
|
def upload(self, source_path, network_trainer, repo_id, private, revision, token=""): |
|
with torch.inference_mode(False): |
|
from huggingface_hub import HfApi |
|
|
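# fall back to hf_token.json in the node directory when no token is provided on the node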
|
if not token: |
|
with open(os.path.join(script_directory, "hf_token.json"), "r") as file: |
|
token_data = json.load(file) |
|
token = token_data["hf_token"] |
|
|
|
|
|
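# write metadata.json two directories above the uploaded file, at the training output root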
|
directory_path = os.path.dirname(os.path.dirname(source_path)) |
|
file_name = os.path.basename(source_path) |
|
|
|
metadata = network_trainer["network_trainer"].metadata |
|
metadata_file_path = os.path.join(directory_path, "metadata.json") |
|
with open(metadata_file_path, 'w') as f: |
|
json.dump(metadata, f, indent=4) |
|
|
|
repo_type = None |
|
api = HfApi(token=token) |
|
|
|
try: |
|
api.repo_info( |
|
repo_id=repo_id, |
|
revision=revision if revision != "" else None, |
|
repo_type=repo_type) |
|
repo_exists = True |
|
logger.info(f"Repository {repo_id} exists.") |
|
except Exception as e: |
|
repo_exists = False |
|
logger.error(f"Repository {repo_id} does not exist. Exception: {e}") |
|
|
|
if not repo_exists: |
|
try: |
|
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private) |
|
except Exception as e: |
|
logger.error("===========================================") |
|
logger.error(f"failed to create HuggingFace repo: {e}") |
|
logger.error("===========================================") |
|
|
|
is_folder = isinstance(source_path, (str, Path)) and Path(source_path).is_dir()
|
logger.info(f"Uploading {'folder' if is_folder else 'file'}: {source_path}")
|
|
|
try: |
|
if is_folder: |
|
api.upload_folder( |
|
repo_id=repo_id, |
|
repo_type=repo_type, |
|
folder_path=source_path, |
|
path_in_repo=file_name, |
|
) |
|
else: |
|
api.upload_file( |
|
repo_id=repo_id, |
|
repo_type=repo_type, |
|
path_or_fileobj=source_path, |
|
path_in_repo=file_name, |
|
) |
|
|
|
if not is_folder: |
|
api.upload_file( |
|
repo_id=repo_id, |
|
repo_type=repo_type, |
|
path_or_fileobj=str(metadata_file_path), |
|
path_in_repo='metadata.json', |
|
) |
|
status = "Uploaded to HuggingFace succesfully" |
|
except Exception as e: |
|
logger.error("===========================================") |
|
logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}") |
|
logger.error("===========================================") |
|
status = f"Failed to upload to HuggingFace {e}" |
|
|
|
return (network_trainer, status,) |
|
|
|
class ExtractFluxLoRA: |
|
@classmethod |
|
def INPUT_TYPES(s): |
|
return { |
|
"required": { |
|
"original_model": (folder_paths.get_filename_list("unet"), ), |
|
"finetuned_model": (folder_paths.get_filename_list("unet"), ), |
|
"output_path": ("STRING", {"default": f"{str(os.path.join(folder_paths.models_dir, 'loras', 'Flux'))}"}), |
|
"dim": ("INT", {"default": 4, "min": 2, "max": 1024, "step": 2, "tooltip": "LoRA rank"}), |
|
"save_dtype": (["fp32", "fp16", "bf16", "fp8_e4m3fn", "fp8_e5m2"], {"default": "bf16", "tooltip": "the dtype to save the LoRA as"}), |
|
"load_device": (["cpu", "cuda"], {"default": "cuda", "tooltip": "the device to load the model to"}), |
|
"store_device": (["cpu", "cuda"], {"default": "cpu", "tooltip": "the device to store the LoRA as"}), |
|
"clamp_quantile": ("FLOAT", {"default": 0.99, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "clamp quantile"}), |
|
"metadata": ("BOOLEAN", {"default": True, "tooltip": "build metadata"}), |
|
"mem_eff_safe_open": ("BOOLEAN", {"default": False, "tooltip": "memory efficient loading"}), |
|
}, |
|
} |
|
|
|
RETURN_TYPES = ("STRING", ) |
|
RETURN_NAMES = ("output_path",) |
|
FUNCTION = "extract" |
|
CATEGORY = "FluxTrainer" |
|
|
|
def extract(self, original_model, finetuned_model, output_path, dim, save_dtype, load_device, store_device, clamp_quantile, metadata, mem_eff_safe_open): |
|
from .flux_extract_lora import svd |
|
transformer_path = folder_paths.get_full_path("unet", original_model) |
|
finetuned_model_path = folder_paths.get_full_path("unet", finetuned_model) |
|
outpath = svd( |
|
model_org = transformer_path, |
|
model_tuned = finetuned_model_path, |
|
save_to = os.path.join(output_path, f"{finetuned_model.replace('.safetensors', '')}_extracted_lora_rank_{dim}-{save_dtype}.safetensors"), |
|
dim = dim, |
|
device = load_device, |
|
store_device = store_device, |
|
save_precision = save_dtype, |
|
clamp_quantile = clamp_quantile, |
|
no_metadata = not metadata, |
|
mem_eff_safe_open = mem_eff_safe_open |
|
) |
|
|
|
return (outpath,) |
|
|
|
NODE_CLASS_MAPPINGS = { |
|
"InitFluxLoRATraining": InitFluxLoRATraining, |
|
"InitFluxTraining": InitFluxTraining, |
|
"FluxTrainModelSelect": FluxTrainModelSelect, |
|
"TrainDatasetGeneralConfig": TrainDatasetGeneralConfig, |
|
"TrainDatasetAdd": TrainDatasetAdd, |
|
"FluxTrainLoop": FluxTrainLoop, |
|
"VisualizeLoss": VisualizeLoss, |
|
"FluxTrainValidate": FluxTrainValidate, |
|
"FluxTrainValidationSettings": FluxTrainValidationSettings, |
|
"FluxTrainEnd": FluxTrainEnd, |
|
"FluxTrainSave": FluxTrainSave, |
|
"FluxKohyaInferenceSampler": FluxKohyaInferenceSampler, |
|
"UploadToHuggingFace": UploadToHuggingFace, |
|
"OptimizerConfig": OptimizerConfig, |
|
"OptimizerConfigAdafactor": OptimizerConfigAdafactor, |
|
"FluxTrainSaveModel": FluxTrainSaveModel, |
|
"ExtractFluxLoRA": ExtractFluxLoRA, |
|
"OptimizerConfigProdigy": OptimizerConfigProdigy, |
|
"FluxTrainResume": FluxTrainResume, |
|
"FluxTrainBlockSelect": FluxTrainBlockSelect, |
|
"TrainDatasetRegularization": TrainDatasetRegularization, |
|
"FluxTrainAndValidateLoop": FluxTrainAndValidateLoop, |
|
"OptimizerConfigProdigyPlusScheduleFree": OptimizerConfigProdigyPlusScheduleFree, |
|
"FluxTrainerLossConfig": FluxTrainerLossConfig, |
|
"TrainNetworkConfig": TrainNetworkConfig, |
|
} |
|
NODE_DISPLAY_NAME_MAPPINGS = { |
|
"InitFluxLoRATraining": "Init Flux LoRA Training", |
|
"InitFluxTraining": "Init Flux Training", |
|
"FluxTrainModelSelect": "FluxTrain ModelSelect", |
|
"TrainDatasetGeneralConfig": "TrainDatasetGeneralConfig", |
|
"TrainDatasetAdd": "TrainDatasetAdd", |
|
"FluxTrainLoop": "Flux Train Loop", |
|
"VisualizeLoss": "Visualize Loss", |
|
"FluxTrainValidate": "Flux Train Validate", |
|
"FluxTrainValidationSettings": "Flux Train Validation Settings", |
|
"FluxTrainEnd": "Flux LoRA Train End", |
|
"FluxTrainSave": "Flux Train Save LoRA", |
|
"FluxKohyaInferenceSampler": "Flux Kohya Inference Sampler", |
|
"UploadToHuggingFace": "Upload To HuggingFace", |
|
"OptimizerConfig": "Optimizer Config", |
|
"OptimizerConfigAdafactor": "Optimizer Config Adafactor", |
|
"FluxTrainSaveModel": "Flux Train Save Model", |
|
"ExtractFluxLoRA": "Extract Flux LoRA", |
|
"OptimizerConfigProdigy": "Optimizer Config Prodigy", |
|
"FluxTrainResume": "Flux Train Resume", |
|
"FluxTrainBlockSelect": "Flux Train Block Select", |
|
"TrainDatasetRegularization": "Train Dataset Regularization", |
|
"FluxTrainAndValidateLoop": "Flux Train And Validate Loop", |
|
"OptimizerConfigProdigyPlusScheduleFree": "Optimizer Config ProdigyPlusScheduleFree", |
|
"FluxTrainerLossConfig": "Flux Trainer Loss Config", |
|
"TrainNetworkConfig": "Train Network Config", |
|
} |
|
|