# from .internvideo2_stage2 import InternVideo2_Stage2 as IV2S2
from transformers import PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig
from .config import InternVideo2Config as config
import warnings
import torch
from torch import nn
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode
from transformers.utils import logging
warnings.filterwarnings("ignore")
from .internvideo2_clip_vision import InternVideo2
from .mobile_clip import TextTransformer, ClipTokenizer
logger = logging.get_logger(__name__)
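
# InternVideo2_CLIP_small pairs an InternVideo2 vision backbone with a
# MobileCLIP-style text transformer behind a shared alignment space, using a
# CLIP-style learnable temperature and optional frozen-tower training.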
class InternVideo2_CLIP_small(PreTrainedModel):
    config_class = config

    def __init__(self, config, tokenizer=None, is_pretrain=True):
        super().__init__(config)
        self.config = config
        self.tokenizer = tokenizer
        self.is_pretrain = is_pretrain
        print(config)
        if tokenizer is None:
            self.tokenizer = ClipTokenizer(self.config.model.text_encoder)
        # self.model = IV2S2(self.config).to('cpu').to(torch.float16)
        self.vision_encoder = self.build_vision_encoder()
        self.vision_align = nn.Sequential(
            nn.LayerNorm(self.config.model.vision_encoder.clip_embed_dim),
            nn.Linear(
                self.config.model.vision_encoder.clip_embed_dim,
                self.config.model.vision_encoder.align_dim
            ),
        )
        self.text_encoder = self.build_text_encoder(
            cfg=self.config.model.text_encoder['text_cfg'],
            projection_dim=self.config.model.text_encoder["embed_dim"],
        )

        # adopt 1 / 100. as in ViCLIP
        self.temp = nn.parameter.Parameter(torch.ones([]) * config.model.temp)
        self.temp_min = config.model.temp_min
        if self.config.model.freeze_vision:
            for name, p in self.vision_encoder.named_parameters():
                if self.config.model.open_vision_clip_projector and name.startswith('clip_projector'):
                    logger.info(f"Unfreeze {name}")
                else:
                    logger.info(f"Freeze {name}")
                    p.requires_grad = False

        if self.config.model.freeze_text:
            for name, p in self.text_encoder.named_parameters():
                if self.config.model.open_text_projection and name.startswith('projection_layer'):
                    logger.info(f"Unfreeze {name}")
                else:
                    logger.info(f"Freeze {name}")
                    p.requires_grad = False
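        # Note: when freezing is enabled, only the projection heads
        # (clip_projector / projection_layer) stay trainable, per the
        # open_vision_clip_projector / open_text_projection flags above.
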
        img_size = self.config.model.vision_encoder.img_size
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (img_size, img_size),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.Lambda(lambda x: x.float().div(255.0)),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )
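        # Note: this transform expects channel-first uint8 frames in [0, 255];
        # it resizes with bicubic interpolation, scales to [0, 1], and applies
        # the ImageNet mean/std hard-coded above.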

    @torch.no_grad()
    def clip_contrastive_temperature(self):
        """Seems only used during pre-training."""
        self.temp.clamp_(min=self.temp_min)

    def encode_vision(self, image, test=False):
        """Encode images / videos as features.

        Args:
            image (torch.Tensor): The input images or videos. Shape: [B,T,C,H,W].
            test (bool): Whether in test mode.

        Returns:
            torch.Tensor: The aligned clip-level features. Shape: [B,C].
        """
        T = image.shape[1]
        use_image = True if T == 1 else False
        image = image.permute(0, 2, 1, 3, 4)  # [B,T,C,H,W] -> [B,C,T,H,W]

        vision_embeds = self.vision_encoder(image, use_image=use_image)
        vision_embeds = self.vision_align(vision_embeds)
        return vision_embeds
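
    # Illustrative usage of encode_vision (hypothetical input; preprocessing via
    # self.transform is assumed to be handled by the caller):
    #     video = torch.randn(1, 8, 3, 224, 224)        # [B,T,C,H,W]
    #     clip_embeds = model.encode_vision(video)      # [B, align_dim]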

    def encode_text(self, text):
        """Encode text.

        Args:
            text (dict): The output of the tokenizer / HuggingFace `PreTrainedTokenizer`, containing keys such as:
                - input_ids (torch.Tensor): Token ids to be fed to the model. Shape: [B,L].
                - attention_mask (torch.Tensor): Mask indicating padded tokens (0 marks padding). Shape: [B,L].
                - other keys: see https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__

        Returns:
            torch.Tensor: The text features. Shape: [B,C].
        """
        text_embeds = self.text_encoder(text)
        return text_embeds
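
    # Illustrative usage of encode_text (assumption: the bundled ClipTokenizer is
    # callable on a list of strings and returns input the text encoder accepts):
    #     tokens = model.tokenizer(["a person riding a bike"])
    #     text_embeds = model.encode_text(tokens)       # [B, embed_dim]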

    def build_vision_encoder(self):
        """Build the vision encoder.

        Returns:
            nn.Module: The InternVideo2 vision encoder.
        """
        vision_encoder = InternVideo2(
            in_chans=self.config.model.vision_encoder.in_chans,
            patch_size=self.config.model.vision_encoder.patch_size,
            img_size=self.config.model.vision_encoder.img_size,
            qkv_bias=self.config.model.vision_encoder.qkv_bias,
            drop_path_rate=self.config.model.vision_encoder.drop_path_rate,
            head_drop_path_rate=self.config.model.vision_encoder.head_drop_path_rate,
            embed_dim=self.config.model.vision_encoder.embed_dim,
            num_heads=self.config.model.vision_encoder.num_heads,
            mlp_ratio=self.config.model.vision_encoder.mlp_ratio,
            init_values=self.config.model.vision_encoder.init_values,
            qk_normalization=self.config.model.vision_encoder.qk_normalization,
            depth=self.config.model.vision_encoder.depth,
            use_flash_attn=self.config.model.vision_encoder.use_flash_attn,
            use_fused_rmsnorm=self.config.model.vision_encoder.use_fused_rmsnorm,
            use_fused_mlp=self.config.model.vision_encoder.use_fused_mlp,
            fused_mlp_heuristic=self.config.model.vision_encoder.fused_mlp_heuristic,
            attn_pool_num_heads=self.config.model.vision_encoder.attn_pool_num_heads,
            clip_embed_dim=self.config.model.vision_encoder.clip_embed_dim,
            layerscale_no_force_fp32=self.config.model.vision_encoder.layerscale_no_force_fp32,
            num_frames=self.config.model.vision_encoder.num_frames,
            tubelet_size=self.config.model.vision_encoder.tubelet_size,
            sep_pos_embed=self.config.model.vision_encoder.sep_pos_embed,
            use_checkpoint=self.config.model.vision_encoder.use_checkpoint,
            checkpoint_num=self.config.model.vision_encoder.checkpoint_num,
        )
        return vision_encoder

    def build_text_encoder(self, cfg, projection_dim):
        """Build the text encoder (and possibly a video-to-text multimodal fusion encoder).

        Returns:
            nn.Module: The text encoder.
        """
        text_encoder = TextTransformer(cfg, projection_dim)
        return text_encoder


if __name__ == "__main__":
    model_config = config()
    model = InternVideo2_CLIP_small(model_config)
    # encode_vision expects [B,T,C,H,W]; no generic forward() is defined in this file.
    x = torch.randn(2, 8, 3, 224, 224, dtype=torch.float16).to(model_config.device)
    output = model.encode_vision(x)
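    # A minimal text-video similarity sketch (assumptions: the tokenizer is
    # callable on a list of strings and returns input encode_text accepts, and
    # both encoders run in the chosen dtype/device; not part of the original
    # smoke test).
    texts = model.tokenizer(["a person riding a bike", "a dog running on grass"])
    text_embeds = model.encode_text(texts)
    # CLIP-style similarity: unit-normalize both sides, scale by the temperature.
    v = output / output.norm(dim=-1, keepdim=True)
    t = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
    print(v @ t.T / model.temp)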