# Hosting-page header captured together with this file (not part of the code):
#   JasonSmithSO's picture
#   Upload 578 files
#   8866644 verified
"""
Adapted from comfyui CLIP code.
https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/sd1_clip.py
"""
import os
from transformers import T5Tokenizer, T5EncoderModel, T5Config, modeling_utils
import torch
import traceback
import zipfile
from comfy import model_management
from comfy.sd1_clip import parse_parentheses, token_weights, escape_important, unescape_important, safe_load_embed_zip, expand_directory_list, load_embed
class T5v11Model(torch.nn.Module):
    """T5 v1.1 encoder wrapped with a ComfyUI-CLIP-style encoding interface.

    The wrapped ``T5EncoderModel`` is stored in ``self.transformer``. It is
    either loaded from ``textmodel_path`` (optionally quantised through
    bitsandbytes) or built un-initialised from a JSON config so that weights
    can be supplied later through :meth:`load_sd`.
    """

    def __init__(self, textmodel_ver="xxl", textmodel_json_config=None, textmodel_path=None, device="cpu", max_length=120, freeze=True, dtype=None):
        """Create the encoder.

        Args:
            textmodel_ver: Variant name used to pick the bundled
                ``t5v11-<ver>_config.json`` when no explicit config is given.
            textmodel_json_config: Path to a T5 config JSON; only used when
                ``textmodel_path`` is None.
            textmodel_path: Path handed to ``T5EncoderModel.from_pretrained``.
                When None, an un-initialised model is built from the config.
            device: Target device (string or ``torch.device``). Only consulted
                when loading from ``textmodel_path``; CUDA devices are passed
                on as ``device_map``.
            max_length: Number of tokens per encoded section.
            freeze: When true, put the model in eval mode and disable grads.
            dtype: A torch dtype, or ``"bnb8bit"`` / ``"bnb4bit"`` to request
                bitsandbytes 8-bit / 4-bit loading.
        """
        super().__init__()
        self.num_layers = 24
        self.max_length = max_length
        self.bnb = False  # set when the weights are bitsandbytes-quantised

        if textmodel_path is not None:
            model_args = {"low_cpu_mem_usage": True}  # Don't take 2x system ram on cpu
            if dtype == "bnb8bit":
                self.bnb = True
                model_args["load_in_8bit"] = True
            elif dtype == "bnb4bit":
                self.bnb = True
                model_args["load_in_4bit"] = True
            elif dtype:
                model_args["torch_dtype"] = dtype
            # second GPU offload hack part 2
            # str() so that a torch.device is accepted as well as a string.
            if str(device).startswith("cuda"):
                model_args["device_map"] = device
            print(f"Loading T5 from '{textmodel_path}'")
            self.transformer = T5EncoderModel.from_pretrained(textmodel_path, **model_args)
        else:
            if textmodel_json_config is None:
                textmodel_json_config = os.path.join(
                    os.path.dirname(os.path.realpath(__file__)),
                    f"t5v11-{textmodel_ver}_config.json"
                )
            config = T5Config.from_json_file(textmodel_json_config)
            self.num_layers = config.num_hidden_layers
            # Skip weight initialisation when building the model from config.
            with modeling_utils.no_init_weights():
                self.transformer = T5EncoderModel(config)

        if freeze:
            self.freeze()
        self.empty_tokens = [[0] * self.max_length]  # <pad> token

    def freeze(self):
        """Switch the encoder to eval mode and disable gradients on all parameters."""
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, tokens):
        """Encode batches of token ids into the encoder's last hidden state.

        Args:
            tokens: Rectangular nested sequence of integer token ids
                (one inner sequence per batch row).

        Returns:
            The ``last_hidden_state`` tensor produced by the encoder. It is
            returned as-is, i.e. on the encoder's device and in its dtype.
        """
        device = self.transformer.get_input_embeddings().weight.device
        tokens = torch.LongTensor(tokens).to(device)

        # Attend to every position up to and including the first </s> token
        # of each row; rows without </s> are attended in full.
        end_id = 1  # </s> token
        is_end = (tokens == end_id).to(tokens.dtype)
        ends_before = is_end.cumsum(dim=1) - is_end
        attention_mask = (ends_before == 0).to(tokens.dtype)

        outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask)
        # NOTE: the previous code ran `z.detach().cpu().float()` here and threw
        # the result away, so callers have always received the raw tensor.
        # That contract is kept; only the wasted copy is gone.
        return outputs['last_hidden_state']

    def encode(self, tokens):
        """Alias for calling the module: encode ``tokens`` via :meth:`forward`."""
        return self(tokens)

    def load_sd(self, sd):
        """Load state dict ``sd`` into the encoder non-strictly and return the load result."""
        return self.transformer.load_state_dict(sd, strict=False)

    def to(self, *args, **kwargs):
        """BNB complains if you try to change the device or dtype.

        Moves/casts the wrapped encoder unless it is bitsandbytes-quantised,
        in which case the request is only logged. Returns ``self`` so the
        usual ``model = model.to(...)`` idiom keeps working.
        """
        if self.bnb:
            print("Thanks to BitsAndBytes, T5 becomes an immovable rock.", args, kwargs)
        else:
            self.transformer.to(*args, **kwargs)
        return self

    def encode_token_weights(self, token_weight_pairs, return_padded=False):
        """Encode weighted token sections.

        Every section is encoded together with an all-``<pad>`` reference
        sequence; each token's embedding is then moved away from (or towards)
        the reference embedding at the same position according to its weight.

        Args:
            token_weight_pairs: List of sections, each a list of
                ``(token, weight)`` pairs. Each section is expected to have
                one pair per encoded position.
            return_padded: When false, the concatenated result is truncated
                to the total number of tokens whose id is not 0 (``<pad>``).

        Returns:
            The weighted embeddings of all sections concatenated along the
            sequence axis, or the (CPU) reference embedding when there are no
            sections.
        """
        to_encode = list(self.empty_tokens)
        for section in token_weight_pairs:
            to_encode.append([pair[0] for pair in section])

        out = self.encode(to_encode)
        z_empty = out[0:1]  # encoding of the all-<pad> reference row

        output = []
        for k in range(1, out.shape[0]):
            z = out[k:k + 1]
            section = token_weight_pairs[k - 1]
            for j in range(z.shape[1]):
                weight = section[j][1]
                # Interpolate/extrapolate relative to the reference embedding.
                z[0][j] = (z[0][j] - z_empty[0][j]) * weight + z_empty[0][j]
            output.append(z)

        if not output:
            return z_empty.cpu()

        out = torch.cat(output, dim=-2)
        if not return_padded:
            # Count number of tokens that aren't <pad>, then use that number as an index.
            keep_index = sum(
                1
                for section in token_weight_pairs
                for pair in section
                if pair[0] != 0
            )
            out = out[:, :keep_index, :]
        return out
class T5v11Tokenizer:
    """
    Prompt tokenizer for T5 v1.1, modelled on the ComfyUI CLIP tokenizer.

    Handles prompt weighting syntax and ``embedding:<name>`` references and
    packs the result into fixed-length sections terminated by ``</s>``.
    """

    def __init__(self, tokenizer_path=None, max_length=120, embedding_directory=None, embedding_size=4096, embedding_key='t5'):
        """Load the tokenizer (bundled ``t5_tokenizer`` folder by default) and cache special ids."""
        if tokenizer_path is None:
            module_dir = os.path.dirname(os.path.realpath(__file__))
            tokenizer_path = os.path.join(module_dir, "t5_tokenizer")
        self.tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)

        self.max_length = max_length
        self.max_tokens_per_section = self.max_length - 1  # </s> but no <BOS>

        # Special token ids are looked up by tokenizing their literal text.
        self.pad_token = self.tokenizer("<pad>", add_special_tokens=False)["input_ids"][0]
        self.end_token = self.tokenizer("</s>", add_special_tokens=False)["input_ids"][0]

        # Reverse vocabulary (id -> piece), used by untokenize().
        self.inv_vocab = {
            token_id: piece
            for piece, token_id in self.tokenizer.get_vocab().items()
        }

        self.embedding_directory = embedding_directory
        self.max_word_length = 8  # haven't verified this
        self.embedding_identifier = "embedding:"
        self.embedding_size = embedding_size
        self.embedding_key = embedding_key

    def _try_get_embedding(self, embedding_name:str):
        '''
        Look up an embedding by name, retrying with surrounding commas removed.
        Returns (embedding, leftover_text); the embedding may be None.
        '''
        embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
        if embed is not None:
            return (embed, "")

        stripped = embedding_name.strip(',')
        if len(stripped) >= len(embedding_name):
            # Nothing was stripped, so there is no second name to try.
            return (embed, "")

        embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
        return (embed, embedding_name[len(stripped):])

    def tokenize_with_weights(self, text:str, return_word_ids=False):
        '''
        Convert a prompt into sections of (token, weight[, word id]) entries.
        A token is either an integer id or a pre-computed embedding tensor.
        Word ids start at 1 and are shared by all tokens of one word or
        embedding; id 0 marks the </s> and <pad> filler entries.
        Every section holds exactly max_length entries.
        '''
        parsed_weights = token_weights(escape_important(text), 1.0)

        # --- step 1: split into words and tokenize each one -----------------
        tokens = []
        for weighted_segment, weight in parsed_weights:
            words = unescape_important(weighted_segment).replace("\n", " ").split(' ')
            for word in words:
                if word == "":
                    continue

                # An "embedding:<name>" word is replaced by the embedding itself.
                if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
                    embedding_name = word[len(self.embedding_identifier):].strip('\n')
                    embed, leftover = self._try_get_embedding(embedding_name)
                    if embed is None:
                        print(f"warning, embedding:{embedding_name} does not exist, ignoring")
                    elif len(embed.shape) == 1:
                        tokens.append([(embed, weight)])
                    else:
                        tokens.append([(embed[x], weight) for x in range(embed.shape[0])])

                    # Text left over after the embedding name is tokenized as
                    # a normal word; otherwise this word is finished.
                    if leftover == "":
                        continue
                    word = leftover

                word_ids = self.tokenizer(word, add_special_tokens=False)["input_ids"]
                tokens.append([(token, weight) for token in word_ids])

        # --- step 2: pack the words into sections of max_length entries -----
        end_entry = (self.end_token, 1.0, 0)
        pad_entry = (self.pad_token, 1.0, 0)

        batch = []
        batched_tokens = [batch]
        for word_id, t_group in enumerate(tokens, start=1):
            # Long words may be split across sections; short ones are kept whole.
            is_large = len(t_group) >= self.max_word_length
            while t_group:
                if len(t_group) + len(batch) <= self.max_length - 1:
                    # The (rest of the) word fits into the current section.
                    batch.extend([(t, w, word_id) for t, w in t_group])
                    break

                remaining_length = self.max_length - len(batch) - 1
                if is_large:
                    # Fill the section with the head of the word, then close it.
                    batch.extend([(t, w, word_id) for t, w in t_group[:remaining_length]])
                    batch.append(end_entry)
                    t_group = t_group[remaining_length:]
                else:
                    # Close and pad the section; the word moves to the next one.
                    batch.append(end_entry)
                    batch.extend([pad_entry] * remaining_length)

                batch = []
                batched_tokens.append(batch)

        # Close and pad the final section.
        batch.extend([end_entry] + [pad_entry] * (self.max_length - len(batch) - 1))

        if return_word_ids:
            return batched_tokens
        return [[(t, w) for t, w, _ in section] for section in batched_tokens]

    def untokenize(self, token_weight_pair):
        """Pair every (token, weight) entry with the vocabulary piece of its token id."""
        return [(pair, self.inv_vocab[pair[0]]) for pair in token_weight_pair]