#!/usr/bin/env -S uv run
# /// script
# requires-python = "<3.13"
# dependencies = [
#     "torch",
#     "torchvision",
#     "huggingface_hub",
#     "timm",
#     "opencv-python",
#     "mediapipe",
# ]
# ///
"""Compare two face images with a timm-based face-recognition model.

Usage::

    ./<script> MODEL_NAME IMAGE_1 IMAGE_2

The checkpoint for MODEL_NAME (looked up in ``timmfrv2.model_configs``) is
downloaded from the Hugging Face Hub into ``models/``. Both images are
cropped with ``utils.align_crop`` and embedded by the model. The cosine
similarity of the two embeddings is then printed as a percentage clamped
to [0, 100].
"""

import os
import sys

import cv2
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download

from timmfrv2 import TimmFRWrapperV2, model_configs
from utils import align_crop


def _read_and_crop(image_filename):
    """Read one image file and return its channel-swapped face crop.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_filename``.
        ValueError: if ``align_crop`` returns no crop.
    """
    # cv2.imread does not raise on a missing/unreadable file; it returns None.
    img = cv2.imread(image_filename)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_filename}")
    crop = align_crop(img)
    # NOTE(review): guards the case where align_crop returns None (e.g. no
    # face found) — confirm whether utils.align_crop can actually do that.
    if crop is None:
        raise ValueError(f"No face crop produced for image: {image_filename}")
    # COLOR_RGB2BGR is a plain R<->B channel swap (identical to COLOR_BGR2RGB).
    # imread yields BGR; the channel order align_crop returns is not visible
    # here — presumably this hands the model RGB input. TODO confirm.
    return cv2.cvtColor(crop, cv2.COLOR_RGB2BGR)


def load_and_crop(image_filename_1, image_filename_2):
    """Return the aligned, channel-swapped face crops of two image files.

    Raises:
        FileNotFoundError: if either image cannot be read.
        ValueError: if no crop is produced for either image.
    """
    return _read_and_crop(image_filename_1), _read_and_crop(image_filename_2)


# ToTensor: HWC uint8 array -> CHW float tensor in [0, 1];
# Normalize(0.5, 0.5): (x - 0.5) / 0.5, i.e. rescale to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(argv=None):
    """Download the model, embed both images and print their match score.

    Args:
        argv: ``[MODEL_NAME, IMAGE_1, IMAGE_2]``; defaults to ``sys.argv[1:]``.
            Extra trailing arguments are ignored.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 3:
        prog = os.path.basename(sys.argv[0])
        sys.exit(f"usage: {prog} MODEL_NAME IMAGE_1 IMAGE_2")
    model_name, image_1, image_2 = args[:3]
    if model_name not in model_configs:
        known = ", ".join(map(str, model_configs))
        sys.exit(f"Unknown model {model_name!r}; expected one of: {known}")
    config = model_configs[model_name]

    print("Download model")
    model_path = hf_hub_download(
        repo_id=config["repo"],
        filename=config["filename"],
        local_dir="models",
    )
    print(f"Model downloaded in {model_path}")

    print("Create model")
    model = TimmFRWrapperV2(config["timm_model"], batchnorm=False)
    model = config["post_setup"](model)
    # The checkpoint comes from the network and is only used as a state dict,
    # so restrict unpickling to tensors and plain containers.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model = model.eval()
    model.to(device)

    crop_a, crop_b = load_and_crop(image_1, image_2)

    with torch.no_grad():
        # [None] adds a leading batch dim of 1 to the CHW tensor; [0][None]
        # takes element 0 of the model output and re-adds a leading dim so
        # cosine_similarity compares along dim 1.
        emb_a = model(transform(crop_a)[None].to(device))[0][None]
        emb_b = model(transform(crop_b)[None].to(device))[0][None]

    similarity = F.cosine_similarity(emb_a, emb_b).item()
    # Cosine similarity lies in [-1, 1]; report it as a 0-100 score,
    # treating any negative similarity as a 0% match.
    pct = max(0.0, min(100.0, similarity * 100))
    print(f"{pct:.2f}% match")

    # Debug: uncomment to inspect the aligned crops.
    # cv2.imshow("crop left", crop_a)
    # cv2.imshow("crop right", crop_b)
    # cv2.waitKey(0)
    # cv2.destroyAllWindows()


if __name__ == "__main__":
    main()