JasonSmithSO's picture
Upload 578 files
8866644 verified
import os
import json
import torch
import folder_paths
from .loader import load_t5
from ..utils.dtype import string_to_dtype
def _register_t5_folder():
    """Create ``<models_dir>/t5`` and register it as the "t5" model folder.

    The new directory is placed first in the search list. Any directories that
    were already registered under "t5" are kept after it. The accepted file
    extensions are set to ``folder_paths.supported_pt_extensions``.
    """
    t5_dir = os.path.join(folder_paths.models_dir, "t5")
    os.makedirs(t5_dir, exist_ok=True)
    # Fall back to an empty (paths, extensions) pair when "t5" is not yet known.
    already_known = folder_paths.folder_names_and_paths.get("t5", [[], set()])[0]
    folder_paths.folder_names_and_paths["t5"] = (
        [t5_dir, *already_known],
        folder_paths.supported_pt_extensions,
    )


_register_t5_folder()
# Dropdown labels for the loader's "dtype" input.
dtypes = ["default", "auto (comfy)", "FP32", "FP16"]
# Note: remove these at some point (bitsandbytes quantized variants)
dtypes.extend(["bnb8bit", "bnb4bit"])

# FP8 labels are offered only when this torch build exposes float8 dtypes.
if hasattr(torch, "float8_e5m2"):
    dtypes += ["FP8 E4M3", "FP8 E5M2"]
else:
    print("Torch version too old for FP8")
class T5v11Loader:
    """Node that loads a T5 v1.1 text encoder from the "t5" model folder."""

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input specification.

        The device choices are "auto", "cpu" and "gpu". Every CUDA device
        after index 0 is appended as "cuda:<index>".
        """
        devices = ["auto", "cpu", "gpu"]
        # hack for using second GPU as offload
        devices.extend(f"cuda:{k}" for k in range(1, torch.cuda.device_count()))
        return {
            "required": {
                "t5v11_name": (folder_paths.get_filename_list("t5"),),
                "t5v11_ver": (["xxl"],),
                "path_type": (["folder", "file"],),
                "device": (devices, {"default": "cpu"}),
                "dtype": (dtypes,),
            }
        }

    RETURN_TYPES = ("T5",)
    FUNCTION = "load_model"
    CATEGORY = "ExtraModels/T5"
    TITLE = "T5v1.1 Loader"

    def load_model(self, t5v11_name, t5v11_ver, path_type, device, dtype):
        """Validate the device/dtype combination, then load the T5 model.

        Args:
            t5v11_name: File or folder name inside the "t5" model folder.
            t5v11_ver: Model version label (e.g. "xxl").
            path_type: "folder" or "file"; passed through to ``load_t5``.
            device: "auto", "cpu", "gpu" or "cuda:<index>".
            dtype: One of the labels in ``dtypes``.

        Returns:
            A 1-tuple containing whatever ``load_t5`` returns.

        Raises:
            ValueError: A bitsandbytes dtype was chosen with a device that is
                neither "gpu" nor "cuda...". Also raised when device is "cpu"
                and the resolved dtype is neither None nor torch.float32.
        """
        # Explicit raises instead of `assert`: asserts are stripped under
        # `python -O`, which would let invalid combinations reach load_t5.
        if "bnb" in dtype and not (device == "gpu" or device.startswith("cuda")):
            raise ValueError("BitsAndBytes only works on CUDA! Set device to 'gpu'.")
        # string_to_dtype maps the label to a torch dtype; presumably None
        # for "default" (see ..utils.dtype) -- TODO confirm.
        dtype = string_to_dtype(dtype, "text_encoder")
        if device == "cpu" and dtype not in [None, torch.float32]:
            raise ValueError(f"Can't use dtype '{dtype}' with CPU! Set dtype to 'default'.")
        return (load_t5(
            model_type="t5v11",
            model_ver=t5v11_ver,
            model_path=folder_paths.get_full_path("t5", t5v11_name),
            path_type=path_type,
            device=device,
            dtype=dtype,
        ),)
class T5TextEncode:
    """Node that encodes a text prompt into CONDITIONING with a loaded T5 model."""

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"
    CATEGORY = "ExtraModels/T5"
    TITLE = "T5 Text Encode"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's inputs: a multiline prompt and a T5 model."""
        required = {
            "text": ("STRING", {"multiline": True}),
            "T5": ("T5",),
        }
        return {"required": required}

    def encode(self, text, T5=None):
        """Tokenize ``text`` with ``T5`` and encode the resulting tokens.

        Returns a 1-tuple holding a single-entry conditioning list of the form
        ``[[cond, {}]]``. Here ``cond`` is whatever ``T5.encode_from_tokens``
        returns, and the empty dict carries no extra options.
        """
        token_batch = T5.tokenize(text)
        embedding = T5.encode_from_tokens(token_batch)
        conditioning = [[embedding, {}]]
        return (conditioning,)
# Node type name -> node class. The names are literal on purpose, so renaming
# a class does not silently change the registered node name.
NODE_CLASS_MAPPINGS = dict(
    T5v11Loader=T5v11Loader,
    T5TextEncode=T5TextEncode,
)