"""Estimate how similar model checkpoints are to a query checkpoint.

Each model's first six encoder attention layers
("enc_p.encoder.attn_layers.{n}") are fingerprinted by running a fixed
random input through their q/k/v projections; each reference model is
scored by the mean cosine similarity between its fingerprints and the
query's.
"""

import hashlib
import logging
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)


def cal_cross_attn(to_q, to_k, to_v, rand_input):
    # The sliced conv weights are square (see eval()), so a Linear declared
    # as hidden_dim -> embed_dim can load them directly.
    hidden_dim, embed_dim = to_q.shape
    attn_to_q = nn.Linear(hidden_dim, embed_dim, bias=False)
    attn_to_k = nn.Linear(hidden_dim, embed_dim, bias=False)
    attn_to_v = nn.Linear(hidden_dim, embed_dim, bias=False)
    attn_to_q.load_state_dict({"weight": to_q})
    attn_to_k.load_state_dict({"weight": to_k})
    attn_to_v.load_state_dict({"weight": to_v})

    # Softmax of the query/key product, combined with the value projection:
    # a deterministic per-layer fingerprint of the weights.
    return torch.einsum(
        "ik, jk -> ik",
        F.softmax(
            torch.einsum("ij, kj -> ik", attn_to_q(rand_input), attn_to_k(rand_input)),
            dim=-1,
        ),
        attn_to_v(rand_input),
    )
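
# Note that the final einsum is not textbook attention (that would be
# "ij, jk -> ik"); it is only a deterministic fingerprint, which is all the
# comparison needs. A minimal sketch of calling it in isolation, where the
# 192x192 shape is an illustrative assumption rather than a value read from
# any checkpoint:
#
#     w = torch.randn(192, 192)
#     x = torch.randn(192, 192)
#     print(cal_cross_attn(w, w, w, x).shape)  # torch.Size([192, 192])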


def model_hash(filename):
    # Cheap identity check: hash only the 64 KiB slice that starts 1 MiB
    # into the file instead of reading the whole checkpoint.
    try:
        with open(filename, "rb") as file:
            m = hashlib.sha256()

            file.seek(0x100000)
            m.update(file.read(0x10000))
            return m.hexdigest()[0:8]
    except FileNotFoundError:
        return "NOFILE"
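
# Usage sketch: model_hash(r"assets\weights\mi v3.pth") returns an
# 8-hex-char prefix, or "NOFILE" if the path is missing. Any file shorter
# than 1 MiB hashes zero bytes and therefore always yields "e3b0c442".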


def eval(model, n, input):
    # q/k/v convolution weights of attention layer n; [:, :, 0] drops the
    # trailing kernel axis, leaving plain square matrices.
    qk = f"enc_p.encoder.attn_layers.{n}.conv_q.weight"
    uk = f"enc_p.encoder.attn_layers.{n}.conv_k.weight"
    vk = f"enc_p.encoder.attn_layers.{n}.conv_v.weight"
    atoq, atok, atov = model[qk][:, :, 0], model[uk][:, :, 0], model[vk][:, :, 0]

    attn = cal_cross_attn(atoq, atok, atov, input)
    return attn
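
# A quick sanity check that a checkpoint exposes the keys eval() expects
# (a sketch, assuming the same "weight" sub-dict that main() unpacks):
#
#     ckpt = torch.load(r"assets\weights\mi v3.pth", map_location="cpu")["weight"]
#     print([k for k in ckpt if k.endswith("conv_q.weight")])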


def main(path, root):
    # Fixed seed so every model is probed with identical random inputs.
    torch.manual_seed(114514)
    model_a = torch.load(path, map_location="cpu")["weight"]

    logger.info("Query:\t\t%s\t%s" % (path, model_hash(path)))

    # Fingerprint the query model once, then release it.
    map_attn_a = {}
    map_rand_input = {}
    for n in range(6):
        hidden_dim, embed_dim, _ = model_a[
            f"enc_p.encoder.attn_layers.{n}.conv_v.weight"
        ].shape
        rand_input = torch.randn([embed_dim, hidden_dim])

        map_attn_a[n] = eval(model_a, n, rand_input)
        map_rand_input[n] = rand_input

    del model_a

    # Score every checkpoint in the reference directory against the query,
    # reusing the saved random inputs so fingerprints stay comparable.
    for name in sorted(list(os.listdir(root))):
        path = "%s/%s" % (root, name)
        model_b = torch.load(path, map_location="cpu")["weight"]

        sims = []
        for n in range(6):
            attn_a = map_attn_a[n]
            attn_b = eval(model_b, n, map_rand_input[n])

            sim = torch.mean(torch.cosine_similarity(attn_a, attn_b))
            sims.append(sim)

        logger.info(
            "Reference:\t%s\t%s\t%s"
            % (path, model_hash(path), f"{torch.mean(torch.stack(sims)) * 1e2:.2f}%")
        )
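
# With logging configured (see below), each run prints one line per
# reference file; hashes and the non-self similarity here are illustrative:
#
#     Query:      assets\weights\mi v3.pth    a1b2c3d4
#     Reference:  assets\weights/mi v3.pth    a1b2c3d4    100.00%
#     Reference:  assets\weights/other.pth    e5f6a7b8    93.17%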


if __name__ == "__main__":
    # Without a handler the INFO records above are dropped silently, so
    # configure basic logging before running.
    logging.basicConfig(level=logging.INFO)

    # Windows-style raw-string paths; adjust to the local checkpoint layout.
    query_path = r"assets\weights\mi v3.pth"
    reference_root = r"assets\weights"
    main(query_path, reference_root)