Spaces:
Running
Running
File size: 1,942 Bytes
726c8e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
#!/usr/bin/env -S uv run
# /// script
# requires-python = "<= 3.12"
# dependencies = [
# "torchvision",
# "huggingface_hub",
# "timm",
# "opencv-python",
# "mediapipe",
# ]
# ///
import os
import sys
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
import cv2
from utils import align_crop
from timmfrv2 import TimmFRWrapperV2, model_configs
# Command-line interface: MODEL_NAME (a key of model_configs) followed by the
# paths of the two images to compare. Exit with a usage message rather than an
# IndexError when arguments are missing; extra arguments are ignored.
if len(sys.argv) < 4:
    sys.exit(f"usage: {sys.argv[0]} MODEL_NAME IMAGE_1 IMAGE_2")
model_name = sys.argv[1]
image_1 = sys.argv[2]
image_2 = sys.argv[3]
def _load_and_crop_one(image_filename):
    """Read one image with OpenCV and return its aligned crop, channel-swapped.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_filename``.
    """
    img = cv2.imread(image_filename)
    # cv2.imread does not raise on a missing/unreadable/unsupported file; it
    # returns None, which would otherwise be handed straight to align_crop.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_filename}")
    # In OpenCV, COLOR_RGB2BGR is the same conversion code as COLOR_BGR2RGB:
    # both swap the first and third channels. cv2.imread returns BGR, so this
    # presumably hands RGB to the model.
    # NOTE(review): confirm the channel order align_crop returns.
    return cv2.cvtColor(align_crop(img), cv2.COLOR_RGB2BGR)


def load_and_crop(image_filename_1, image_filename_2):
    """Load two images from disk and return their aligned crops.

    Each image is read with ``cv2.imread``, cropped by ``utils.align_crop`` and
    has its first and third channels swapped.

    Args:
        image_filename_1: Path of the first image.
        image_filename_2: Path of the second image.

    Returns:
        A ``(crop_1, crop_2)`` tuple, one crop per input, in argument order.

    Raises:
        FileNotFoundError: if either image cannot be read by OpenCV.
    """
    crop_1 = _load_and_crop_one(image_filename_1)
    crop_2 = _load_and_crop_one(image_filename_2)
    return crop_1, crop_2
# Per-crop preprocessing before the model:
#   1. ToTensor converts an H x W x C image into a C x H x W float tensor. Per
#      torchvision's docs, uint8 input is scaled to [0, 1]. This assumes the
#      crop is a uint8 array (as OpenCV produces); confirm align_crop's dtype.
#   2. Normalize computes (x - 0.5) / 0.5 on every channel, mapping [0, 1]
#      onto [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)]
)
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fail early with the list of known models instead of a bare KeyError below.
if model_name not in model_configs:
    known = ", ".join(map(str, model_configs))
    sys.exit(f"Unknown model {model_name!r}; available models: {known}")
model_config = model_configs[model_name]

print("Download model")
# Fetch the checkpoint from the Hugging Face Hub into ./models; the return
# value is the local path of the downloaded file.
model_path = hf_hub_download(
    repo_id=model_config["repo"],
    filename=model_config["filename"],
    local_dir="models",
)
print(f"Model downloaded in {model_path}")

print("Create model")
model = TimmFRWrapperV2(model_config["timm_model"], batchnorm=False)
# The result of post_setup is rebound, so it may wrap or replace the module
# rather than only mutating it in place.
model = model_config["post_setup"](model)
# NOTE(security): torch.load deserializes the downloaded checkpoint with pickle.
# On PyTorch versions where `weights_only` defaults to False, this can execute
# arbitrary code embedded in the file. If the checkpoint is a plain state_dict,
# consider passing weights_only=True.
# Tensors are mapped onto the CPU first; the whole model is moved to `device`
# afterwards.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model = model.eval()  # inference mode (affects dropout / batch-norm layers)
model.to(device)
crop_a, crop_b = load_and_crop(image_1, image_2)

# Embed each crop on its own, as a batch of one, with autograd disabled. For
# each crop, keep element 0 of the model output and re-add a leading batch
# axis. Crops are processed in order: crop_a first, then crop_b.
with torch.no_grad():
    ea, eb = [
        model(transform(crop).unsqueeze(0).to(device))[0][None]
        for crop in (crop_a, crop_b)
    ]

# Report cosine similarity as a percentage, clamped into [0, 100].
score = F.cosine_similarity(ea, eb).item() * 100
pct = max(0, min(100, float(score)))
print(f"{pct:.2f}% match")
# cv2.imshow("crop left", crop_a)
# cv2.imshow("crop right", crop_b)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
|