"""
Adapted from the ComfyUI CLIP code.
https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/sd1_clip.py
"""
import os
from transformers import T5Tokenizer, T5EncoderModel, T5Config, modeling_utils
import torch
import traceback
import zipfile
from comfy import model_management
from comfy.sd1_clip import parse_parentheses, token_weights, escape_important, unescape_important, safe_load_embed_zip, expand_directory_list, load_embed
class T5v11Model(torch.nn.Module):
def __init__(self, textmodel_ver="xxl", textmodel_json_config=None, textmodel_path=None, device="cpu", max_length=120, freeze=True, dtype=None):
super().__init__()
self.num_layers = 24
self.max_length = max_length
self.bnb = False
if textmodel_path is not None:
model_args = {}
model_args["low_cpu_mem_usage"] = True # Don't take 2x system ram on cpu
if dtype == "bnb8bit":
self.bnb = True
model_args["load_in_8bit"] = True
elif dtype == "bnb4bit":
self.bnb = True
model_args["load_in_4bit"] = True
else:
if dtype: model_args["torch_dtype"] = dtype
self.bnb = False
# second GPU offload hack part 2
if device.startswith("cuda"):
model_args["device_map"] = device
print(f"Loading T5 from '{textmodel_path}'")
self.transformer = T5EncoderModel.from_pretrained(textmodel_path, **model_args)
else:
if textmodel_json_config is None:
textmodel_json_config = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
f"t5v11-{textmodel_ver}_config.json"
)
config = T5Config.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers
with modeling_utils.no_init_weights():
self.transformer = T5EncoderModel(config)
if freeze:
self.freeze()
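        # an all-<pad> sequence; encode_token_weights() uses its encoding as the
        # unweighted baseline to interpolate weighted tokens against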
self.empty_tokens = [[0] * self.max_length] # <pad> token
def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
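    # Encodes a batch of token id sequences. The attention mask covers every token
    # up to and including the first </s>; everything after it is treated as padding.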
def forward(self, tokens):
device = self.transformer.get_input_embeddings().weight.device
tokens = torch.LongTensor(tokens).to(device)
attention_mask = torch.zeros_like(tokens)
        end_token = 1  # token id of </s>
        for x in range(attention_mask.shape[0]):
            for y in range(attention_mask.shape[1]):
                attention_mask[x, y] = 1
                if tokens[x, y] == end_token:
break
outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask)
z = outputs['last_hidden_state']
        z = z.detach().cpu().float()
return z
def encode(self, tokens):
return self(tokens)
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False)
    def to(self, *args, **kwargs):
        """BNB complains if you try to change the device or dtype"""
        if self.bnb:
            print("Thanks to BitsAndBytes, T5 becomes an immovable rock.", args, kwargs)
        else:
            self.transformer.to(*args, **kwargs)
        return self  # match nn.Module.to(), which returns the module
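    # Applies prompt weights in embedding space: each token's encoding is linearly
    # interpolated between the empty (all-<pad>) encoding and its real encoding,
    # z = (z - z_empty) * weight + z_empty, then all sections are concatenated.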
def encode_token_weights(self, token_weight_pairs, return_padded=False):
to_encode = list(self.empty_tokens)
for x in token_weight_pairs:
tokens = list(map(lambda a: a[0], x))
to_encode.append(tokens)
out = self.encode(to_encode)
z_empty = out[0:1]
output = []
for k in range(1, out.shape[0]):
z = out[k:k+1]
for i in range(len(z)):
for j in range(len(z[i])):
weight = token_weight_pairs[k - 1][j][1]
z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
output.append(z)
        if len(output) == 0:
return z_empty.cpu()
out = torch.cat(output, dim=-2)
if not return_padded:
# Count number of tokens that aren't <pad>, then use that number as an index.
keep_index = sum([sum([1 for y in x if y[0] != 0]) for x in token_weight_pairs])
out = out[:, :keep_index, :]
return out
class T5v11Tokenizer:
"""
This is largely just based on the ComfyUI CLIP code.
"""
def __init__(self, tokenizer_path=None, max_length=120, embedding_directory=None, embedding_size=4096, embedding_key='t5'):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
self.tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
self.max_length = max_length
self.max_tokens_per_section = self.max_length - 1 # </s> but no <BOS>
self.pad_token = self.tokenizer("<pad>", add_special_tokens=False)["input_ids"][0]
self.end_token = self.tokenizer("</s>", add_special_tokens=False)["input_ids"][0]
vocab = self.tokenizer.get_vocab()
self.inv_vocab = {v: k for k, v in vocab.items()}
self.embedding_directory = embedding_directory
self.max_word_length = 8 # haven't verified this
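        # words prefixed with this identifier are treated as textual-inversion embeddings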
self.embedding_identifier = "embedding:"
self.embedding_size = embedding_size
self.embedding_key = embedding_key
def _try_get_embedding(self, embedding_name:str):
'''
Takes a potential embedding name and tries to retrieve it.
        Returns a tuple of (embedding, leftover string); the embedding may be None.
'''
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
if embed is None:
stripped = embedding_name.strip(',')
if len(stripped) < len(embedding_name):
embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
return (embed, embedding_name[len(stripped):])
return (embed, "")
def tokenize_with_weights(self, text:str, return_word_ids=False):
        '''
        Takes a prompt and converts it to a list of (token, weight, word id) elements.
        Tokens can be either integer token ids or precomputed T5 embedding tensors.
        Word ids are unique per word and embedding; id 0 is reserved for non-word tokens.
        The returned list has dimensions NxM, where M is the input size of T5.
        '''
text = escape_important(text)
parsed_weights = token_weights(text, 1.0)
#tokenize words
tokens = []
for weighted_segment, weight in parsed_weights:
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
to_tokenize = [x for x in to_tokenize if x != ""]
for word in to_tokenize:
#if we find an embedding, deal with the embedding
if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
embedding_name = word[len(self.embedding_identifier):].strip('\n')
embed, leftover = self._try_get_embedding(embedding_name)
if embed is None:
print(f"warning, embedding:{embedding_name} does not exist, ignoring")
else:
if len(embed.shape) == 1:
tokens.append([(embed, weight)])
else:
tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
#if we accidentally have leftover text, continue parsing using leftover, else move on to next word
if leftover != "":
word = leftover
else:
continue
#parse word
tokens.append([(t, weight) for t in self.tokenizer(word, add_special_tokens=False)["input_ids"]])
#reshape token array to T5 input size
batched_tokens = []
batch = []
batched_tokens.append(batch)
for i, t_group in enumerate(tokens):
#determine if we're going to try and keep the tokens in a single batch
is_large = len(t_group) >= self.max_word_length
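            # words at or above max_word_length may be split across sections; shorter
            # words are kept whole by padding out the current section and starting fresh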
while len(t_group) > 0:
if len(t_group) + len(batch) > self.max_length - 1:
remaining_length = self.max_length - len(batch) - 1
#break word in two and add end token
if is_large:
batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
batch.append((self.end_token, 1.0, 0))
t_group = t_group[remaining_length:]
#add end token and pad
else:
batch.append((self.end_token, 1.0, 0))
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
#start new batch
batch = []
batched_tokens.append(batch)
else:
batch.extend([(t,w,i+1) for t,w in t_group])
t_group = []
# fill last batch
batch.extend([(self.end_token, 1.0, 0)] + [(self.pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
# instead of filling, just add EOS (DEBUG)
# batch.extend([(self.end_token, 1.0, 0)])
if not return_word_ids:
            batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
return batched_tokens
def untokenize(self, token_weight_pair):
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
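
if __name__ == "__main__":
    # Usage sketch (illustrative only): the checkpoint id below is an assumption;
    # point textmodel_path at whichever T5 v1.1 encoder checkpoint you actually use,
    # and make sure the bundled t5_tokenizer files are available for the tokenizer.
    tokenizer = T5v11Tokenizer()
    model = T5v11Model(textmodel_path="google/t5-v1_1-xxl", dtype=torch.float16)
    batched = tokenizer.tokenize_with_weights("a photo of a (red:1.2) fox")
    cond = model.encode_token_weights(batched)
    print(cond.shape)  # [1, n_kept_tokens, hidden_dim]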