Spaces:
Running
Running
#!/usr/bin/env -S uv run | |
# /// script | |
# requires-python = "<= 3.12" | |
# dependencies = [ | |
# "torchvision", | |
# "huggingface_hub", | |
# "timm", | |
# "opencv-python", | |
# "mediapipe", | |
# "timm", | |
# ] | |
# /// | |
import os | |
import sys | |
import torch | |
import torch.nn.functional as F | |
import torchvision.transforms as transforms | |
from huggingface_hub import hf_hub_download | |
import cv2 | |
from utils import align_crop | |
from timmfrv2 import TimmFRWrapperV2, model_configs | |
# Command-line interface: MODEL_NAME IMAGE_1 IMAGE_2. Any arguments after the
# third are ignored. Exit with a usage message instead of an IndexError when
# too few are given.
if len(sys.argv) < 4:
    _prog = sys.argv[0] if sys.argv else "script"
    sys.exit(f"usage: {_prog} MODEL_NAME IMAGE_1 IMAGE_2")
model_name = sys.argv[1]  # key into model_configs
image_1 = sys.argv[2]  # path of the first image to compare
image_2 = sys.argv[3]  # path of the second image to compare
def load_and_crop(image_filename_1, image_filename_2):
    """Read two images from disk and return their aligned crops.

    Each image is read with ``cv2.imread``, passed through ``align_crop``,
    and then has its first and third channels swapped with ``cv2.cvtColor``.

    Args:
        image_filename_1: Path of the first image.
        image_filename_2: Path of the second image.

    Returns:
        A ``(crop_1, crop_2)`` tuple in the same order as the arguments.

    Raises:
        FileNotFoundError: If either file cannot be read as an image.
    """
    return _read_and_crop(image_filename_1), _read_and_crop(image_filename_2)


def _read_and_crop(image_filename):
    """Read one image, align/crop it with ``align_crop``, and swap R and B."""
    img = cv2.imread(image_filename)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # rather than raising; fail here with a clear message instead of
        # somewhere inside align_crop.
        raise FileNotFoundError(f"could not read image: {image_filename}")
    # COLOR_RGB2BGR is a plain R<->B channel swap (same as COLOR_BGR2RGB).
    # imread yields BGR; presumably align_crop keeps that order, making the
    # result RGB for the model -- TODO confirm against utils.align_crop.
    return cv2.cvtColor(align_crop(img), cv2.COLOR_RGB2BGR)
# Preprocessing applied to each crop before it is fed to the model:
# ToTensor turns an HxWxC image into a CxHxW float tensor (scaled to [0, 1]
# for uint8 input), then Normalize maps every channel through
# (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
# Run on the default CUDA device when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Look up the chosen model's configuration once. Give a readable error, rather
# than a bare KeyError, when the name is not one of the known configs.
if model_name not in model_configs:
    sys.exit(
        f"unknown model {model_name!r}; choose one of: "
        + ", ".join(map(str, model_configs))
    )
config = model_configs[model_name]

print("Download model")
model_path = hf_hub_download(
    repo_id=config["repo"],
    filename=config["filename"],
    local_dir="models",
)
print(f"Model downloaded in {model_path}")

print("Create model")
model = TimmFRWrapperV2(config["timm_model"], batchnorm=False)
# Per-model hook that may adjust the freshly built network before its
# weights are loaded; it returns the model to use.
model = config["post_setup"](model)
# The checkpoint comes from a remote repository, so treat it as untrusted:
# weights_only=True limits unpickling to tensors and plain containers instead
# of arbitrary Python objects (it is also torch.load's default since 2.6).
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model = model.eval()  # inference mode for layers such as dropout
model.to(device)
# Crop both images, run each crop through the model separately, and report
# the cosine similarity of the two outputs as a percentage.
crop_a, crop_b = load_and_crop(image_1, image_2)
with torch.no_grad():
    # Preprocess, add a leading batch axis, and move to the model's device.
    batch_a = transform(crop_a).unsqueeze(0).to(device)
    # Keep the first element of the model output (presumably the embedding
    # for the single input) and re-add the batch axis for cosine_similarity.
    ea = model(batch_a)[0].unsqueeze(0)
    batch_b = transform(crop_b).unsqueeze(0).to(device)
    eb = model(batch_b)[0].unsqueeze(0)
    similarity = F.cosine_similarity(ea, eb).item()
    pct = float(similarity * 100)
    # Clamp into [0, 100]; negative similarities are reported as 0%.
    pct = max(0, min(100, pct))
print(f"{pct:.2f}% match")
# cv2.imshow("crop left", crop_a) | |
# cv2.imshow("crop right", crop_b) | |
# cv2.waitKey(0) | |
# cv2.destroyAllWindows() | |