cora / utils /dino_utils.py
armikaeili's picture
code added
79c5088
import torch
from torchvision import transforms
class DINOv2Processor:
    """Compare images by the cosine similarity of their DINOv2 embeddings.

    The backbone is fetched from the ``facebookresearch/dinov2`` torch.hub
    repository when the processor is constructed, switched to eval mode and
    moved to ``device``.

    Args:
        model_name: torch.hub entry point to load (e.g. ``"dinov2_vitb14"``).
        device: Device the model and the input tensors are placed on.
        image_size: Side length, in pixels, of the square crop fed to the
            model. NOTE(review): DINOv2's patch embedding appears to require a
            multiple of the patch size (14 for the ``*14`` entry points; the
            default 518 == 37 * 14) -- confirm before changing it.
    """

    # Standard ImageNet per-channel mean / std used for input normalization.
    # Tuples so the shared class-level values cannot be mutated by accident.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, model_name="dinov2_vitb14", device="cpu", image_size=518):
        self.model_name = model_name
        self.device = device
        self.image_size = image_size
        self.model = self._load_model()

    def _load_model(self):
        """Load the backbone from torch.hub, in eval mode, on ``self.device``."""
        model = torch.hub.load('facebookresearch/dinov2', self.model_name)
        model.eval()
        model.to(self.device)
        return model

    def _preprocess_image(self, image):
        """Turn a PIL image into a normalized ``(3, image_size, image_size)`` tensor.

        Images whose mode is not ``"RGB"`` (e.g. ``"L"``, ``"P"``, ``"RGBA"``,
        ``"CMYK"``) are converted to RGB first: ``ToTensor`` keeps the source
        channel count, and the 3-channel ``Normalize`` step fails on 1- or
        4-channel tensors.
        """
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")
        # Built per call (cheap next to the model forward) so that a later
        # change to self.image_size is still honoured.
        preprocess = transforms.Compose([
            # An int size resizes the shorter edge; the crop then makes it square.
            transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ])
        return preprocess(image)

    def compute_similarity(self, pil_image1, pil_image2):
        """Return the cosine similarity between two images' DINOv2 embeddings.

        Both images are preprocessed and embedded in a single batched forward
        pass.

        Args:
            pil_image1: First PIL image.
            pil_image2: Second PIL image.

        Returns:
            A Python float, approximately in ``[-1, 1]``. Returns ``0.0``
            rather than NaN if either embedding is the zero vector.
        """
        batch = torch.stack(
            [self._preprocess_image(pil_image1), self._preprocess_image(pil_image2)]
        ).to(self.device)
        with torch.no_grad():
            feats = self.model(batch)
        # Assumes one embedding row per image, i.e. shape (2, D), as the
        # original dim=1 normalization did -- TODO confirm for other hub entries.
        # F.normalize clamps the norm at eps, so a zero vector stays zero.
        feats = torch.nn.functional.normalize(feats, dim=1)
        return (feats[0] * feats[1]).sum().item()