import os
import torch
import comfy.utils
import comfy.model_patcher
from comfy import model_management
import folder_paths

from .t5v11 import T5v11Model, T5v11Tokenizer

class EXM_T5v11:
	def __init__(self, textmodel_ver="xxl", embedding_directory=None, textmodel_path=None, no_init=False, device="cpu", dtype=None):
		if no_init:
			return

		if device == "auto":
			size = 0
			self.load_device = model_management.text_encoder_device()
			self.offload_device = model_management.text_encoder_offload_device()
			self.init_device = "cpu"
		elif dtype == "bnb8bit":
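			# bitsandbytes weights can't be sized automatically, so use a fixed estimate
			size = 12.4 * (1024**3)
			# They also can't be moved between devices, so load, offload and init all use the GPU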
			self.load_device = model_management.get_torch_device()
			self.offload_device = self.load_device
			self.init_device = self.load_device
		elif dtype == "bnb4bit":
			# 4-bit seems to use about the same VRAM as 8-bit on Pascal GPUs (unverified)
			size = 6.2 * (1024**3)
			self.load_device = model_management.get_torch_device()
			self.offload_device = self.load_device
			self.init_device = self.load_device
		elif device == "cpu":
			size = 0
			self.load_device = "cpu"
			self.offload_device = "cpu"
			self.init_device = "cpu"
		elif device.startswith("cuda"):
			print("Direct CUDA device override!\nVRAM will not be freed by default.")
			size = 0
			self.load_device = device
			self.offload_device = device
			self.init_device = device
		else:
			size = 0
			self.load_device = model_management.get_torch_device()
			self.offload_device = "cpu"
			self.init_device = "cpu"

		self.cond_stage_model = T5v11Model(
			textmodel_ver  = textmodel_ver,
			textmodel_path = textmodel_path,
			device         = device,
			dtype          = dtype,
		)
		self.tokenizer = T5v11Tokenizer(embedding_directory=embedding_directory)
		self.patcher = comfy.model_patcher.ModelPatcher(
			self.cond_stage_model,
			load_device    = self.load_device,
			offload_device = self.offload_device,
			current_device = self.load_device,
			size           = size,
		)

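	def clone(self):
		# Shallow copy: share the model and tokenizer, but give the clone its own patcher
		n = EXM_T5v11(no_init=True)
		n.patcher = self.patcher.clone()
		n.cond_stage_model = self.cond_stage_model
		n.tokenizer = self.tokenizer
		# no_init skips setting these, but load_model() reads load_device
		n.load_device = self.load_device
		n.offload_device = self.offload_device
		n.init_device = self.init_device
		return n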

	def tokenize(self, text, return_word_ids=False):
		return self.tokenizer.tokenize_with_weights(text, return_word_ids)

	def encode_from_tokens(self, tokens):
		self.load_model()
		return self.cond_stage_model.encode_token_weights(tokens)

	def encode(self, text):
		tokens = self.tokenize(text)
		return self.encode_from_tokens(tokens)

	def load_sd(self, sd):
		return self.cond_stage_model.load_sd(sd)

	def get_sd(self):
		return self.cond_stage_model.state_dict()

	def load_model(self):
		if self.load_device != "cpu":
			model_management.load_model_gpu(self.patcher)
		return self.patcher

	def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
		return self.patcher.add_patches(patches, strength_patch, strength_model)

	def get_key_patches(self):
		return self.patcher.get_key_patches()
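

# Sketch (hypothetical helper, not called by the loader): the device/dtype
# branches of EXM_T5v11.__init__ reduced to a pure function, so the decision
# table can be read and tested without ComfyUI. The strings "GPU", "TE_LOAD"
# and "TE_OFFLOAD" are assumed placeholders for the values returned by
# model_management.get_torch_device(), text_encoder_device() and
# text_encoder_offload_device(). Note the check order: device="auto" wins over
# a bnb dtype, but a bnb dtype wins over device="cpu" or "cuda:N".
_SKETCH_GIB = 1024 ** 3

def _sketch_select_devices(device="cpu", dtype=None):
	"""Return (size_bytes, load_device, offload_device, init_device)."""
	if device == "auto":
		# Let ComfyUI pick the text encoder devices; build the model on CPU
		return 0, "TE_LOAD", "TE_OFFLOAD", "cpu"
	if dtype == "bnb8bit":
		# Fixed size estimate; bitsandbytes weights stay on the GPU
		return 12.4 * _SKETCH_GIB, "GPU", "GPU", "GPU"
	if dtype == "bnb4bit":
		return 6.2 * _SKETCH_GIB, "GPU", "GPU", "GPU"
	if device == "cpu":
		return 0, "cpu", "cpu", "cpu"
	if device.startswith("cuda"):
		# Explicit CUDA override: never offloaded
		return 0, device, device, device
	# Fallback: build on CPU, run on the main GPU, offload back to CPU
	return 0, "GPU", "cpu", "cpu"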


def load_t5(model_type, model_ver, model_path, path_type="file", device="cpu", dtype=None):
	assert model_type in ["t5v11"] # Only supported model for now
	model_args = {
		"textmodel_ver" : model_ver,
		"device" : device,
		"dtype"  : dtype,
	}

	if path_type == "folder":
		# pass directly to transformers and initialize there
		# this is to avoid having to handle multi-file state dict loading for now.
		model_args["textmodel_path"] = os.path.dirname(model_path)
		return EXM_T5v11(**model_args)
	else:
		# Known issue: with torch.int8 weights this path returns garbage output or runs out of memory (cause unknown)
		model = EXM_T5v11(**model_args)
		sd = comfy.utils.load_torch_file(model_path)
		model.load_sd(sd)
		return model
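

# Sketch (hypothetical toy classes, for illustration only): the no_init clone
# pattern used by EXM_T5v11.clone(). _SketchPatcher stands in for ComfyUI's
# ModelPatcher and _SketchWrapper for EXM_T5v11; both are assumptions. The
# clone shares the underlying model but gets an independent patch dict, and it
# must copy every attribute later methods read (here load_device), because the
# no_init constructor sets nothing.
class _SketchPatcher:
	def __init__(self, model):
		self.model = model
		self.patches = {}

	def clone(self):
		# Same model object, separate copy of the patch dict
		n = _SketchPatcher(self.model)
		n.patches = dict(self.patches)
		return n

class _SketchWrapper:
	def __init__(self, model=None, load_device="cpu", no_init=False):
		if no_init:
			return
		self.model = model
		self.load_device = load_device
		self.patcher = _SketchPatcher(model)

	def clone(self):
		n = _SketchWrapper(no_init=True)
		n.patcher = self.patcher.clone()
		n.model = self.model
		# Without this line, n.load_model() raises AttributeError
		n.load_device = self.load_device
		return n

	def load_model(self):
		return self.load_device, self.patcher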