import torch import torch.nn as nn class DINOv2Featurizer(nn.Module): def __init__(self): super().__init__() self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').cuda() # self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg') self.model.eval() self.config = {} def get_cls_token(self, img): pass def forward(self, img, include_cls): feature_dict = self.model.forward_features(img) _, _, h, w = img.shape new_h, new_w = h // 14, w // 14 b, _, c = feature_dict["x_norm_patchtokens"].shape spatial_tokens = feature_dict["x_norm_patchtokens"].permute(0, 2, 1).reshape(b, c, new_h, new_w) if include_cls: return spatial_tokens, feature_dict["x_norm_clstoken"] else: return spatial_tokens if __name__ == "__main__": import torchvision.transforms as T from PIL import Image from shared import norm, crop_to_divisor device = "cuda" if torch.cuda.is_available() else "cpu" image = Image.open("../../samples/dog_man_1_crop.jpg") load_size = 224 # * 3 transform = T.Compose([ T.Resize(load_size, Image.BILINEAR), T.CenterCrop(load_size), T.ToTensor(), norm]) model = DINOv2Featurizer().cuda() results = model(transform(image).cuda().unsqueeze(0), include_cls=False) print(results.shape)