File size: 1,471 Bytes
491eded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from typing import Literal
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from transformers import AutoModel


class DINOv2ImageEncoder(nn.Module):
    """Frozen DINOv2 backbone that maps RGB image batches to token features.

    The Hugging Face checkpoint is loaded once, its parameters are frozen with
    ``requires_grad_(False)`` and it is switched to ``eval()`` mode.
    ``forward`` normalizes the input with the per-channel mean/std constants
    below, casts it to the backbone's weight dtype and returns the backbone's
    ``last_hidden_state``.

    Args:
        model_name: Hugging Face checkpoint identifier to load.
        torch_dtype: Dtype the backbone weights are loaded in (default
            ``torch.bfloat16``, as before).
        max_size: Largest accepted input height/width in pixels (default 518,
            as before).
    """

    def __init__(
        self,
        model_name: Literal[
            "facebook/dinov2-with-registers-large",
            "facebook/dinov2-large",
        ],
        *,
        torch_dtype: torch.dtype = torch.bfloat16,
        max_size: int = 518,
    ):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
        # Used as a frozen feature extractor: no parameter gradients, eval mode.
        self.model.requires_grad_(False)
        self.model.eval()

        # Per-channel RGB normalization constants, shaped (1, 3, 1, 1) so they
        # broadcast over a (B, 3, H, W) batch. Registered as non-persistent
        # buffers: they follow .to(device) but are not stored in state_dict.
        DINOv2_INPUT_MEAN = torch.as_tensor([0.485, 0.456, 0.406], dtype=torch.float32)[
            None, :, None, None
        ]
        DINOv2_INPUT_STD = torch.as_tensor([0.229, 0.224, 0.225], dtype=torch.float32)[
            None, :, None, None
        ]
        self.register_buffer("DINOv2_INPUT_MEAN", DINOv2_INPUT_MEAN, persistent=False)
        self.register_buffer("DINOv2_INPUT_STD", DINOv2_INPUT_STD, persistent=False)
        self.max_size = max_size
        self.hidden_size = self.model.config.hidden_size

    def preprocess(self, image: torch.Tensor) -> torch.Tensor:
        """Validate and normalize a batch of RGB images.

        Args:
            image: Tensor of shape ``(B, 3, H, W)`` with ``H`` and ``W`` at most
                ``self.max_size``. NOTE(review): the mean/std constants are the
                usual ImageNet statistics, which presumably expect pixel values
                scaled to [0, 1] — confirm with callers.

        Returns:
            ``(image - mean) / std`` in the dtype and on the device of ``image``.

        Raises:
            ValueError: If ``image`` is not 4-D, does not have 3 channels, or
                exceeds ``self.max_size`` in height or width.
        """
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if image.ndim != 4:
            raise ValueError(
                f"expected a 4-D (B, C, H, W) tensor, got shape {tuple(image.shape)}"
            )
        _, channels, height, width = image.shape
        if channels != 3:
            raise ValueError(f"expected 3 (RGB) channels, got {channels}")
        if height > self.max_size or width > self.max_size:
            raise ValueError(
                f"image size {height}x{width} exceeds max_size={self.max_size}"
            )
        # Bring the float32 buffers to the input's dtype/device before broadcasting.
        mean = self.DINOv2_INPUT_MEAN.to(image)
        std = self.DINOv2_INPUT_STD.to(image)
        return (image - mean) / std

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Encode a batch of RGB images with the frozen backbone.

        Args:
            image: Tensor accepted by :meth:`preprocess`, in any floating dtype.

        Returns:
            The backbone's ``last_hidden_state``, computed in the backbone's
            weight dtype.
        """
        pixel_values = self.preprocess(image)
        # The weights live in `torch_dtype` (bfloat16 by default). Feeding a
        # float32 batch into bf16 weights fails with a dtype mismatch, so cast
        # the input to whatever dtype the backbone's parameters currently have.
        model_dtype = next(self.model.parameters()).dtype
        pixel_values = pixel_values.to(dtype=model_dtype)
        return self.model(pixel_values=pixel_values).last_hidden_state