mac9087 commited on
Commit
9745daf
·
verified ·
1 Parent(s): 04bf072

Create tokenizers/image.py

Browse files
Files changed (1) hide show
  1. tsr/models/tokenizers/image.py +66 -0
tsr/models/tokenizers/image.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from huggingface_hub import hf_hub_download
7
+ from transformers.models.vit.modeling_vit import ViTModel
8
+
9
+ from ...utils import BaseModule
10
+
11
+
12
+ class DINOSingleImageTokenizer(BaseModule):
13
+ @dataclass
14
+ class Config(BaseModule.Config):
15
+ pretrained_model_name_or_path: str = "facebook/dino-vitb16"
16
+ enable_gradient_checkpointing: bool = False
17
+
18
+ cfg: Config
19
+
20
+ def configure(self) -> None:
21
+ self.model: ViTModel = ViTModel(
22
+ ViTModel.config_class.from_pretrained(
23
+ hf_hub_download(
24
+ repo_id=self.cfg.pretrained_model_name_or_path,
25
+ filename="config.json",
26
+ )
27
+ )
28
+ )
29
+
30
+ if self.cfg.enable_gradient_checkpointing:
31
+ self.model.encoder.gradient_checkpointing = True
32
+
33
+ self.register_buffer(
34
+ "image_mean",
35
+ torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
36
+ persistent=False,
37
+ )
38
+ self.register_buffer(
39
+ "image_std",
40
+ torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
41
+ persistent=False,
42
+ )
43
+
44
+ def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
45
+ packed = False
46
+ if images.ndim == 4:
47
+ packed = True
48
+ images = images.unsqueeze(1)
49
+
50
+ batch_size, n_input_views = images.shape[:2]
51
+ images = (images - self.image_mean) / self.image_std
52
+ out = self.model(
53
+ rearrange(images, "B N C H W -> (B N) C H W"), interpolate_pos_encoding=True
54
+ )
55
+ local_features, global_features = out.last_hidden_state, out.pooler_output
56
+ local_features = local_features.permute(0, 2, 1)
57
+ local_features = rearrange(
58
+ local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
59
+ )
60
+ if packed:
61
+ local_features = local_features.squeeze(1)
62
+
63
+ return local_features
64
+
65
+ def detokenize(self, *args, **kwargs):
66
+ raise NotImplementedError