import os
import torch
import folder_paths

from .loader import load_t5
from ..utils.dtype import string_to_dtype

# create and register the custom "t5" model folder with ComfyUI's folder_paths
os.makedirs(
	os.path.join(folder_paths.models_dir,"t5"),
	exist_ok = True,
)
folder_paths.folder_names_and_paths["t5"] = (
	[
		os.path.join(folder_paths.models_dir,"t5"),
		*folder_paths.folder_names_and_paths.get("t5", [[],set()])[0]
	],
	folder_paths.supported_pt_extensions
)

dtypes = [
	"default",
	"auto (comfy)",
	"FP32",
	"FP16",
	# Note: remove these at some point
	"bnb8bit",
	"bnb4bit",
]
# FP8 dtypes are only available on newer torch builds
try:
	torch.float8_e5m2
except AttributeError:
	print("Torch version too old for FP8")
else:
	dtypes += ["FP8 E4M3", "FP8 E5M2"]

# Loader node: loads a T5 v1.1 text encoder checkpoint from the "t5" models folder
class T5v11Loader:
	@classmethod
	def INPUT_TYPES(s):
		devices = ["auto", "cpu", "gpu"]
		# expose additional CUDA devices so a second GPU can be used for offload
		for k in range(1, torch.cuda.device_count()):
			devices.append(f"cuda:{k}")
		return {
			"required": {
				"t5v11_name": (folder_paths.get_filename_list("t5"),),
				"t5v11_ver": (["xxl"],),
				"path_type": (["folder", "file"],),
				"device": (devices, {"default":"cpu"}),
				"dtype": (dtypes,),
			}
		}
	RETURN_TYPES = ("T5",)
	FUNCTION = "load_model"
	CATEGORY = "ExtraModels/T5"
	TITLE = "T5v1.1 Loader"

	def load_model(self, t5v11_name, t5v11_ver, path_type, device, dtype):
		if "bnb" in dtype:
			assert device == "gpu" or device.startswith("cuda"), "BitsAndBytes only works on CUDA! Set device to 'gpu'."
		dtype = string_to_dtype(dtype, "text_encoder")
		if device == "cpu":
			assert dtype in [None, torch.float32], f"Can't use dtype '{dtype}' with CPU! Set dtype to 'default'."

		return (load_t5(
			model_type = "t5v11",
			model_ver  = t5v11_ver,
			model_path = folder_paths.get_full_path("t5", t5v11_name),
			path_type  = path_type,
			device     = device,
			dtype      = dtype,
		),)

# Encode node: tokenizes a prompt and encodes it into T5 conditioning
class T5TextEncode:
	@classmethod
	def INPUT_TYPES(s):
		return {
			"required": {
				"text": ("STRING", {"multiline": True}),
				"T5": ("T5",),
			}
		}

	RETURN_TYPES = ("CONDITIONING",)
	FUNCTION = "encode"
	CATEGORY = "ExtraModels/T5"
	TITLE = "T5 Text Encode"

	def encode(self, text, T5):
		tokens = T5.tokenize(text)
		cond = T5.encode_from_tokens(tokens)
		return ([[cond, {}]], )

NODE_CLASS_MAPPINGS = {
	"T5v11Loader"  : T5v11Loader,
	"T5TextEncode" : T5TextEncode,
}