File size: 1,942 Bytes
726c8e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/usr/bin/env -S uv run

# /// script
# requires-python = "<= 3.12"
# dependencies = [
#     "torchvision",
#     "huggingface_hub",
#     "timm",
#     "opencv-python",
#     "mediapipe",
#     "torch",
# ]
# ///

import os
import sys

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
import cv2

from utils import align_crop
from timmfrv2 import TimmFRWrapperV2, model_configs

# Command-line interface: <model_name> <image_1> <image_2>.
# Fail early with a usage message instead of an IndexError when arguments are
# missing; any extra arguments are ignored, as before.
if len(sys.argv) < 4:
    sys.exit(f"Usage: {sys.argv[0]} <model_name> <image_1> <image_2>")

model_name, image_1, image_2 = sys.argv[1:4]


def load_and_crop(image_filename_1, image_filename_2):
    """Read two image files and return their aligned crops.

    Each file is read with ``cv2.imread``, passed through ``align_crop`` and
    then has its channel order swapped with ``cv2.cvtColor``.

    Args:
        image_filename_1: Path of the first image file.
        image_filename_2: Path of the second image file.

    Returns:
        A ``(crop_1, crop_2)`` tuple, in the same order as the arguments.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read one of the files
            (it signals failure by returning ``None`` rather than raising).
    """
    crops = []
    for filename in (image_filename_1, image_filename_2):
        img = cv2.imread(filename)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {filename}")
        # NOTE(review): COLOR_RGB2BGR simply reverses the channel order.
        # cv2.imread yields BGR data, so whether the result here is RGB or
        # BGR depends on what align_crop returns -- TODO confirm.
        crops.append(cv2.cvtColor(align_crop(img), cv2.COLOR_RGB2BGR))
    crop_1, crop_2 = crops
    return crop_1, crop_2


# Preprocessing applied to each crop before it is fed to the model:
# convert to a tensor, then subtract 0.5 and divide by 0.5 on each of the
# three channels.
_CHANNEL_MEAN = [0.5, 0.5, 0.5]
_CHANNEL_STD = [0.5, 0.5, 0.5]
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD)
transform = transforms.Compose([_to_tensor, _normalize])

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Resolve the configuration for the requested model once, and exit with a
# readable message (instead of a bare KeyError) when the name is unknown.
try:
    model_config = model_configs[model_name]
except KeyError:
    available_models = ", ".join(map(str, model_configs))
    sys.exit(f"Unknown model {model_name!r}; available models: {available_models}")

# Fetch the checkpoint from the Hugging Face Hub into the local "models" dir.
print("Download model")
model_path = hf_hub_download(
    repo_id=model_config["repo"],
    filename=model_config["filename"],
    local_dir="models",
)
print(f"Model downloaded in {model_path}")

# Build the network, apply the per-model post-setup hook, then load weights.
print("Create model")
model = TimmFRWrapperV2(model_config["timm_model"], batchnorm=False)
model = model_config["post_setup"](model)
# NOTE(review): torch.load unpickles the downloaded checkpoint. On torch
# versions where weights_only defaults to False, a malicious file can execute
# arbitrary code; consider passing weights_only=True once the checkpoints are
# confirmed to be plain state dicts.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model = model.eval()
model.to(device)

crop_a, crop_b = load_and_crop(image_1, image_2)

# Embed each crop separately with gradient tracking disabled: add a leading
# batch axis, move to the target device, run the model, and keep the first
# output entry with a leading axis restored.
embeddings = []
with torch.no_grad():
    for face in (crop_a, crop_b):
        batch = transform(face).unsqueeze(0).to(device)
        embeddings.append(model(batch)[0][None])
ea, eb = embeddings

# Report the cosine similarity as a percentage clamped to [0, 100].
similarity = F.cosine_similarity(ea, eb).item()
pct = float(similarity * 100)
pct = min(100, pct)
pct = max(0, pct)
print(f"{pct:.2f}% match")

# cv2.imshow("crop left", crop_a)
# cv2.imshow("crop right", crop_b)
# cv2.waitKey(0)
# cv2.destroyAllWindows()