File size: 9,885 Bytes
8866644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
"""
Adapted from comfyui CLIP code.
https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/sd1_clip.py
"""

import os

from transformers import T5Tokenizer, T5EncoderModel, T5Config, modeling_utils
import torch
import traceback
import zipfile
from comfy import model_management

from comfy.sd1_clip import parse_parentheses, token_weights, escape_important, unescape_important, safe_load_embed_zip, expand_directory_list, load_embed

class T5v11Model(torch.nn.Module):
    """
    Wrapper around an encoder-only T5 model that turns token ids into
    per-token hidden states, with ComfyUI-style prompt weighting.

    The wrapped ``T5EncoderModel`` is either loaded from ``textmodel_path``
    or built without weight initialisation from a JSON config, in which case
    the weights are expected to be supplied later through ``load_sd``.
    """
    def __init__(self, textmodel_ver="xxl", textmodel_json_config=None, textmodel_path=None, device="cpu", max_length=120, freeze=True, dtype=None):
        """
        Args:
            textmodel_ver: Suffix used to pick the bundled
                ``t5v11-<ver>_config.json`` when no explicit config is given.
            textmodel_json_config: Path to a T5 config JSON. Only read when
                ``textmodel_path`` is None.
            textmodel_path: Location handed to ``T5EncoderModel.from_pretrained``.
            device: Target device (string or ``torch.device``). Only used when
                loading from ``textmodel_path``: a CUDA device is forwarded
                as ``device_map``.
            max_length: Number of tokens per prompt section; also the length
                of the all-``<pad>`` reference prompt.
            freeze: When true, switch to eval mode and disable gradients.
            dtype: ``"bnb8bit"`` / ``"bnb4bit"`` to request 8/4-bit loading,
                otherwise a value forwarded as ``torch_dtype`` (if truthy).
        """
        super().__init__()

        self.num_layers = 24
        self.max_length = max_length
        self.bnb = False

        if textmodel_path is not None:
            model_args = {"low_cpu_mem_usage": True} # Don't take 2x system ram on cpu
            if dtype == "bnb8bit":
                self.bnb = True
                model_args["load_in_8bit"] = True
            elif dtype == "bnb4bit":
                self.bnb = True
                model_args["load_in_4bit"] = True
            elif dtype:
                model_args["torch_dtype"] = dtype
            # second GPU offload hack part 2
            # str() so that torch.device objects (and None) are accepted too,
            # not only plain strings.
            if str(device).startswith("cuda"):
                model_args["device_map"] = device
            print(f"Loading T5 from '{textmodel_path}'")
            self.transformer = T5EncoderModel.from_pretrained(textmodel_path, **model_args)
        else:
            if textmodel_json_config is None:
                textmodel_json_config = os.path.join(
                    os.path.dirname(os.path.realpath(__file__)),
                    f"t5v11-{textmodel_ver}_config.json"
                )
            config = T5Config.from_json_file(textmodel_json_config)
            self.num_layers = config.num_hidden_layers
            # Skip random weight init; weights are presumably loaded
            # afterwards via load_sd().
            with modeling_utils.no_init_weights():
                self.transformer = T5EncoderModel(config)

        if freeze:
            self.freeze()
        self.empty_tokens = [[0] * self.max_length] # <pad> token

    def freeze(self):
        """Put the transformer in eval mode and disable gradients for all parameters."""
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    @staticmethod
    def _eos_attention_mask(tokens, eos_token=1):
        """
        Build an attention mask that is 1 up to and including the first
        ``eos_token`` of every row and 0 afterwards. Rows without an
        ``eos_token`` are fully attended.

        Args:
            tokens: 2D integer tensor of token ids (rows are sequences).
            eos_token: Id of the ``</s>`` token.

        Returns:
            Tensor with the same shape, dtype and device as ``tokens``.
        """
        is_eos = (tokens == eos_token).to(torch.long)
        # Number of </s> tokens strictly before each position; a position is
        # attended exactly when that count is zero.
        eos_before = is_eos.cumsum(dim=1) - is_eos
        return (eos_before == 0).to(tokens.dtype)

    def forward(self, tokens):
        """
        Run the encoder on a batch of equally long token id lists.

        Args:
            tokens: Nested list of token ids, one inner list per sequence.

        Returns:
            The encoder's ``last_hidden_state``, left on the model's device
            and in the model's dtype.
        """
        device = self.transformer.get_input_embeddings().weight.device
        tokens = torch.LongTensor(tokens).to(device)
        attention_mask = self._eos_attention_mask(tokens) # </s> token is id 1

        outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask)

        # NOTE: an earlier revision evaluated `z.detach().cpu().float()` here
        # and threw the result away. That never affected the returned tensor,
        # so the wasted device-to-host copy was dropped; the return value is
        # unchanged.
        return outputs['last_hidden_state']

    def encode(self, tokens):
        """Alias for calling the module on ``tokens``."""
        return self(tokens)

    def load_sd(self, sd):
        """Load state dict ``sd`` into the transformer non-strictly and return the load result."""
        return self.transformer.load_state_dict(sd, strict=False)

    def to(self, *args, **kwargs):
        """
        BNB complains if you try to change the device or dtype, so the call
        is only forwarded to the transformer when not quantised.

        Returns:
            ``self``, matching ``torch.nn.Module.to`` so that
            ``model = model.to(...)`` keeps working.
        """
        if self.bnb:
            print("Thanks to BitsAndBytes, T5 becomes an immovable rock.", args, kwargs)
        else:
            self.transformer.to(*args, **kwargs)
        return self

    def encode_token_weights(self, token_weight_pairs, return_padded=False):
        """
        Encode weighted prompt sections and apply the weights by
        interpolating each token's hidden state away from the hidden state of
        an all-``<pad>`` prompt.

        Args:
            token_weight_pairs: List of sections; each section is a list of
                entries whose item 0 is the token id and item 1 the weight.
                All sections must have the same length as the ``<pad>``
                reference prompt (``max_length``).
            return_padded: When false, the result is truncated to the number
                of non-``<pad>`` (id != 0) tokens across all sections.

        Returns:
            Tensor with the sections concatenated along the token axis. With
            no sections, the encoding of the empty prompt moved to the CPU.
        """
        to_encode = list(self.empty_tokens)
        for section in token_weight_pairs:
            to_encode.append([pair[0] for pair in section])

        out = self.encode(to_encode)
        z_empty = out[0:1]

        output = []
        for k, section in enumerate(token_weight_pairs, start=1):
            z = out[k:k+1]
            for j in range(z.shape[1]):
                weight = section[j][1]
                # weight 1.0 keeps the token as encoded, 0.0 collapses it
                # onto the empty-prompt state at the same position.
                z[0, j] = (z[0, j] - z_empty[0][j]) * weight + z_empty[0][j]
            output.append(z)

        if not output:
            return z_empty.cpu()

        out = torch.cat(output, dim=-2)
        if not return_padded:
            # Count number of tokens that aren't <pad>, then use that number as an index.
            keep_index = sum(1 for section in token_weight_pairs for pair in section if pair[0] != 0)
            out = out[:, :keep_index, :]
        return out


class T5v11Tokenizer:
    """
    Prompt tokenizer for T5 with prompt weighting and "embedding:<name>"
    support, producing fixed-length sections terminated by </s>.
    This is largely just based on the ComfyUI CLIP code.
    """
    def __init__(self, tokenizer_path=None, max_length=120, embedding_directory=None, embedding_size=4096, embedding_key='t5'):
        '''
        tokenizer_path: location passed to T5Tokenizer.from_pretrained; defaults
            to the "t5_tokenizer" folder next to this file.
        max_length: number of tokens per section, including the closing </s>.
        embedding_directory: where embeddings are looked up; None disables
            "embedding:" handling.
        embedding_size / embedding_key: forwarded unchanged to load_embed.
        '''
        if tokenizer_path is None:
            tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        self.tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
        self.max_length = max_length
        self.max_tokens_per_section = self.max_length - 1 # </s> but no <BOS>

        self.pad_token = self.tokenizer("<pad>", add_special_tokens=False)["input_ids"][0]
        self.end_token = self.tokenizer("</s>", add_special_tokens=False)["input_ids"][0]
        vocab = self.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.embedding_directory = embedding_directory
        self.max_word_length = 8 # haven't verified this
        self.embedding_identifier = "embedding:"
        self.embedding_size = embedding_size
        self.embedding_key = embedding_key

    def _try_get_embedding(self, embedding_name:str):
        '''
        Takes a potential embedding name and tries to retrieve it.
        Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
        '''
        embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
        if embed is None:
            # Retry without trailing commas ("embedding:foo," in a prompt).
            # Only the right side is stripped: the leftover below is sliced
            # from the end of the name, so stripping leading commas as well
            # would produce a wrong leftover.
            stripped = embedding_name.rstrip(',')
            if len(stripped) < len(embedding_name):
                embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
                return (embed, embedding_name[len(stripped):])
        return (embed, "")

    def tokenize_with_weights(self, text:str, return_word_ids=False):
        '''
        Takes a prompt and converts it to a list of (token, weight, word id) elements.
        Tokens can both be integer tokens and pre computed T5 tensors.
        Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
        Returned list has the dimensions NxM where M is the input size of T5
        Word ids are only included when return_word_ids is true.
        '''
        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)

        #tokenize words
        tokens = []
        for weighted_segment, weight in parsed_weights:
            to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
            to_tokenize = [x for x in to_tokenize if x != ""]
            for word in to_tokenize:
                #if we find an embedding, deal with the embedding
                if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
                    embedding_name = word[len(self.embedding_identifier):].strip('\n')
                    embed, leftover = self._try_get_embedding(embedding_name)
                    if embed is None:
                        print(f"warning, embedding:{embedding_name} does not exist, ignoring")
                    else:
                        if len(embed.shape) == 1:
                            tokens.append([(embed, weight)])
                        else:
                            tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
                    #if we accidentally have leftover text, continue parsing using leftover, else move on to next word
                    if leftover != "":
                        word = leftover
                    else:
                        continue
                #parse word
                tokens.append([(t, weight) for t in self.tokenizer(word, add_special_tokens=False)["input_ids"]])

        #reshape token array to T5 input size
        batched_tokens = self._batch_tokens(tokens)

        if not return_word_ids:
            batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
        return batched_tokens

    def _batch_tokens(self, tokens):
        '''
        Packs per-word token groups into sections of exactly max_length entries.
        tokens is a list of groups, each a list of (token, weight) pairs.
        Returns a list of sections made of (token, weight, word id) triples; every
        section ends with the end token followed by pad tokens (word id 0).
        Raises ValueError if a section has no room for a token besides </s>.
        '''
        capacity = self.max_length - 1 # one slot per section is reserved for </s>
        batch = []
        batched_tokens = [batch]
        for i, t_group in enumerate(tokens):
            #determine if we're going to try and keep the tokens in a single batch
            #a group that can never fit into one section has to be split as well,
            #otherwise the loop below would open new sections forever
            is_large = len(t_group) >= self.max_word_length or len(t_group) > capacity

            while len(t_group) > 0:
                if len(t_group) + len(batch) > capacity:
                    if capacity < 1:
                        raise ValueError("max_length must be at least 2 to fit a token and the end token")
                    remaining_length = capacity - len(batch)
                    #break word in two and add end token
                    if is_large:
                        batch.extend([(t, w, i + 1) for t, w in t_group[:remaining_length]])
                        batch.append((self.end_token, 1.0, 0))
                        t_group = t_group[remaining_length:]
                    #add end token and pad
                    else:
                        batch.append((self.end_token, 1.0, 0))
                        batch.extend([(self.pad_token, 1.0, 0)] * remaining_length)
                    #start new batch
                    batch = []
                    batched_tokens.append(batch)
                else:
                    batch.extend([(t, w, i + 1) for t, w in t_group])
                    t_group = []

        # fill last batch
        batch.extend([(self.end_token, 1.0, 0)] + [(self.pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
        return batched_tokens

    def untokenize(self, token_weight_pair):
        '''
        Pairs every entry with the vocabulary string of its token id (item 0).
        Raises KeyError for token ids that are not in the vocabulary.
        '''
        return [(pair, self.inv_vocab[pair[0]]) for pair in token_weight_pair]