Spaces:
Configuration error
Configuration error
import os | |
import json | |
import torch | |
import folder_paths | |
from .loader import load_t5 | |
from ..utils.dtype import string_to_dtype | |
# Register the custom "t5" model folder: create it under the models directory
# and list it first among the search paths stored under the "t5" key.
_t5_dir = os.path.join(folder_paths.models_dir, "t5")
os.makedirs(_t5_dir, exist_ok=True)

# Paths already registered under "t5" (if any) are kept, after the default one.
_t5_registered = folder_paths.folder_names_and_paths.get("t5", [[], set()])
folder_paths.folder_names_and_paths["t5"] = (
    [_t5_dir] + list(_t5_registered[0]),
    folder_paths.supported_pt_extensions,
)
# Names of the precision options listed for the loader node's "dtype" input.
dtypes = ["default", "auto (comfy)", "FP32", "FP16"]
# Note: remove these at some point
dtypes.extend(["bnb8bit", "bnb4bit"])

# The FP8 options are appended only when torch exposes `float8_e5m2`.
if hasattr(torch, "float8_e5m2"):
    dtypes.extend(["FP8 E4M3", "FP8 E5M2"])
else:
    print("Torch version too old for FP8")
class T5v11Loader:
    """Node that loads a T5 v1.1 model through `load_t5` and outputs it as "T5"."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification for this node.

        Declared as a classmethod so it can be called on the class itself
        (``T5v11Loader.INPUT_TYPES()``) without creating an instance.
        """
        devices = ["auto", "cpu", "gpu"]
        # hack for using second GPU as offload
        for k in range(1, torch.cuda.device_count()):
            devices.append(f"cuda:{k}")
        return {
            "required": {
                "t5v11_name": (folder_paths.get_filename_list("t5"),),
                "t5v11_ver": (["xxl"],),
                "path_type": (["folder", "file"],),
                "device": (devices, {"default": "cpu"}),
                "dtype": (dtypes,),
            }
        }

    RETURN_TYPES = ("T5",)
    FUNCTION = "load_model"
    CATEGORY = "ExtraModels/T5"
    TITLE = "T5v1.1 Loader"

    def load_model(self, t5v11_name, t5v11_ver, path_type, device, dtype):
        """Load the selected T5 model and return it as a one-element tuple.

        Args:
            t5v11_name: File name resolved via ``folder_paths.get_full_path("t5", ...)``.
            t5v11_ver: Model version string forwarded to ``load_t5`` as ``model_ver``.
            path_type: "folder" or "file", forwarded to ``load_t5`` unchanged.
            device: One of the device strings offered by ``INPUT_TYPES``.
            dtype: One of the option names in ``dtypes``; converted with
                ``string_to_dtype`` before loading.

        Raises:
            AssertionError: If a "bnb" dtype is requested on a non-CUDA device,
                or if the CPU device is combined with a converted dtype other
                than ``None`` / ``torch.float32``.
        """
        # NOTE(review): these checks are plain asserts, so they are skipped
        # when Python runs with -O.
        if "bnb" in dtype:
            assert device == "gpu" or device.startswith("cuda"), "BitsAndBytes only works on CUDA! Set device to 'gpu'."
        # The "bnb" check above needs the option name, so convert only afterwards.
        dtype = string_to_dtype(dtype, "text_encoder")
        if device == "cpu":
            assert dtype in [None, torch.float32], f"Can't use dtype '{dtype}' with CPU! Set dtype to 'default'."

        return (load_t5(
            model_type = "t5v11",
            model_ver = t5v11_ver,
            model_path = folder_paths.get_full_path("t5", t5v11_name),
            path_type = path_type,
            device = device,
            dtype = dtype,
        ),)
class T5TextEncode:
    """Node that encodes a text prompt with a loaded "T5" object."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification for this node.

        Declared as a classmethod so it can be called on the class itself
        (``T5TextEncode.INPUT_TYPES()``) without creating an instance.
        """
        return {
            "required": {
                "text": ("STRING", {"multiline": True}),
                "T5": ("T5",),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"
    CATEGORY = "ExtraModels/T5"
    TITLE = "T5 Text Encode"

    def encode(self, text, T5=None):
        """Tokenize and encode ``text`` with the given T5 object.

        Args:
            text: Prompt string passed to ``T5.tokenize``.
            T5: Object providing ``tokenize`` and ``encode_from_tokens``.
                The default of ``None`` is kept for signature compatibility,
                but a real object must be supplied.

        Returns:
            A one-element tuple containing ``[[cond, {}]]``, where ``cond`` is
            the result of ``T5.encode_from_tokens``.
        """
        tokens = T5.tokenize(text)
        cond = T5.encode_from_tokens(tokens)
        return ([[cond, {}]], )
# Maps each exported node identifier to the class that implements it.
NODE_CLASS_MAPPINGS = dict(
    T5v11Loader=T5v11Loader,
    T5TextEncode=T5TextEncode,
)