import torch

from .. import shared


class Embedding:
|
def __init__(self, vec, name, step=None): |
|
self.vec = vec |
|
self.name = name |
|
self.step = step |
|
self.shape = None |
|
self.vectors = 0 |
|
self.cached_checksum = None |
|
self.sd_checkpoint = None |
|
self.sd_checkpoint_name = None |
|
self.optimizer_state_dict = None |
|
self.filename = None |
|
|
|
self.shape = vec.shape[-1] |
|
self.vectors = vec.shape[0] |
|
|
|
    def save(self, filename):
        # Mirrors the original textual inversion .pt layout; "*" is the
        # placeholder key the format stores everything under, and 265 is a
        # fixed token id kept for compatibility with that format.
        embedding_data = {
            "string_to_token": {"*": 265},
            "string_to_param": {"*": self.vec},
            "name": self.name,
            "step": self.step,
            "sd_checkpoint": self.sd_checkpoint,
            "sd_checkpoint_name": self.sd_checkpoint_name,
        }

        torch.save(embedding_data, filename)

        # Optionally persist optimizer state next to the embedding so training
        # can resume later; checksum() ties the .optim file to this vector.
        if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
            optimizer_saved_dict = {
                'hash': self.checksum(),
                'optimizer_state_dict': self.optimizer_state_dict,
            }
            torch.save(optimizer_saved_dict, f"{filename}.optim")

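    # Round-trip sketch (illustrative only; the file name is an assumption,
    # not part of this module):
    #
    #   data = torch.load("my-embedding.pt", map_location="cpu")
    #   vec = data["string_to_param"]["*"]   # recovers the saved tensor
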
    def checksum(self):
        if self.cached_checksum is not None:
            return self.cached_checksum

        # Cheap deterministic 32-bit rolling hash over the flattened vector.
        def const_hash(a):
            r = 0
            for v in a:
                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
            return r

        # Scale by 100 so small float differences still affect the hash, then
        # keep the low 16 bits as a four-hex-digit digest.
        self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
        return self.cached_checksum
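
# Minimal usage sketch for Embedding (illustrative only; the tensor shape and
# file name below are assumptions made for the example):
#
#   vec = torch.zeros(2, 768)            # two vectors of width 768
#   emb = Embedding(vec, "my-embedding")
#   emb.save("my-embedding.pt")          # writes the dict built in save()
#   emb.checksum()                       # '0000' for an all-zeros vector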


class EmbeddingDatabase:
    def __init__(self):
        # Maps a name's first token id to a longest-first list of
        # (token ids, Embedding) candidates; see register_embedding.
        self.ids_lookup = {}
        self.word_embeddings = {}
        self.skipped_embeddings = {}
        self.expected_shape = -1
        self.embedding_dirs = {}
        self.previously_displayed_embeddings = ()

    def register_embedding(self, embedding, model):
        self.word_embeddings[embedding.name] = embedding

        # Tokenize the embedding's name; lookups key on the first token id so
        # matching during inference stays cheap.
        ids = model.tokenize([embedding.name])[0]

        first_id = ids[0]
        if first_id not in self.ids_lookup:
            self.ids_lookup[first_id] = []

        # Keep candidates sorted longest-first so a longer multi-token name
        # wins over a shorter prefix in find_embedding_at_position.
        self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)

        return embedding

    def find_embedding_at_position(self, tokens, offset):
        token = tokens[offset]
        possible_matches = self.ids_lookup.get(token, None)

        if possible_matches is None:
            return None, None

        # Candidates are longest-first, so the first full match is also the
        # longest; return it with the number of token ids it consumed.
        for ids, embedding in possible_matches:
            if tokens[offset:offset + len(ids)] == ids:
                return embedding, len(ids)

        return None, None
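
# Minimal usage sketch for EmbeddingDatabase (illustrative; `model` stands in
# for any wrapper whose tokenize() returns a list of token-id lists, which is
# what this module assumes):
#
#   db = EmbeddingDatabase()
#   db.register_embedding(emb, model)
#   tokens = model.tokenize(["a photo of my-embedding"])[0]
#   found, consumed = db.find_embedding_at_position(tokens, offset)
#   # found is the Embedding (or None if nothing starts at `offset`);
#   # consumed is how many token ids the match spans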