# (HuggingFace Spaces page banner removed — extraction artifact: "Spaces: Running on Zero")
import torch
from torchvision import transforms
class DINOv2Processor:
    """Compute perceptual similarity between two PIL images using DINOv2 features.

    Loads a DINOv2 backbone from torch.hub, preprocesses images with the
    standard ImageNet pipeline, and scores image pairs by the cosine
    similarity of their (L2-normalized) embeddings.
    """

    def __init__(self, model_name="dinov2_vitb14", device="cpu", image_size=518):
        """Initialize the processor.

        Args:
            model_name: torch.hub model id under 'facebookresearch/dinov2'
                (e.g. "dinov2_vits14", "dinov2_vitb14").
            device: torch device string the model and inputs are moved to.
            image_size: side length (px) for resize + center crop; 518 is the
                DINOv2 default input resolution (37 * 14-px patches).
        """
        self.model_name = model_name
        self.device = device
        self.image_size = image_size
        self.model = self._load_model()
        # Build the transform pipeline once here instead of on every
        # _preprocess_image call — it is invariant across calls.
        self._preprocess = transforms.Compose([
            transforms.Resize(self.image_size, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(self.image_size),
            transforms.ToTensor(),
            # Standard ImageNet normalization, matching DINOv2's training setup.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
        ])

    def _load_model(self):
        """Fetch the DINOv2 model from torch.hub and prepare it for inference.

        Returns:
            The model in eval mode, moved to ``self.device``.
        """
        model = torch.hub.load('facebookresearch/dinov2', self.model_name)
        model.eval()
        model.to(self.device)
        return model

    def _preprocess_image(self, image):
        """Convert a PIL image to a normalized CHW float tensor.

        Args:
            image: PIL.Image in RGB mode (assumed — TODO confirm callers
                convert non-RGB inputs before calling).

        Returns:
            torch.Tensor of shape (3, image_size, image_size).
        """
        return self._preprocess(image)

    def compute_similarity(self, pil_image1, pil_image2):
        """Return the cosine similarity of the two images' DINOv2 embeddings.

        Args:
            pil_image1: first PIL image.
            pil_image2: second PIL image.

        Returns:
            float in [-1, 1]; higher means more perceptually similar.
        """
        # Stack both images into one batch so the backbone runs a single
        # forward pass instead of two.
        batch = torch.stack([
            self._preprocess_image(pil_image1),
            self._preprocess_image(pil_image2),
        ]).to(self.device)
        with torch.no_grad():
            feats = self.model(batch)
        # L2-normalize; clamp the norm to avoid division by zero for a
        # degenerate all-zero embedding.
        feats = feats / feats.norm(dim=1, keepdim=True).clamp_min(1e-12)
        similarity = (feats[0] * feats[1]).sum()
        return similarity.item()