# CLIP4IDC modeling code: a CLIP ViT image encoder over "before"/"after" image
# pairs plus a Transformer caption decoder for image difference captioning.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
from torch import nn

from .until_module import PreTrainedModel
from .module_cross import CrossModel, CrossConfig
from .module_decoder import DecoderModel, DecoderConfig
from utils.module_clip import CLIP, convert_weights
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
def update_attr(target_name, target_config, target_attr_name, source_config, source_attr_name, default_value=None):
    # Copy an attribute from source_config onto target_config unless it still equals the default.
    if hasattr(source_config, source_attr_name):
        if default_value is None or getattr(source_config, source_attr_name) != default_value:
            setattr(target_config, target_attr_name, getattr(source_config, source_attr_name))
    return target_config
class CLIP4IDCPreTrainedModel(PreTrainedModel, nn.Module):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, cross_config, decoder_config, *inputs, **kwargs):
        super(CLIP4IDCPreTrainedModel, self).__init__(cross_config, decoder_config)
        self.cross_config = cross_config
        self.decoder_config = decoder_config
        self.clip = None
        self.cross = None
    @classmethod
    def from_pretrained(cls, cross_model_name, decoder_model_name, state_dict=None, cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
        if state_dict is None:
            state_dict = {}

        pretrained_clip_name = "ViT-B/16"
        clip_state_dict = CLIP.get_config(pretrained_clip_name=pretrained_clip_name)
        for key, val in clip_state_dict.items():
            new_key = "clip." + key
            if new_key not in state_dict:
                state_dict[new_key] = val.clone()

        cross_config, _ = CrossConfig.get_config(cross_model_name, cache_dir, type_vocab_size, state_dict=None)
        decoder_config, _ = DecoderConfig.get_config(decoder_model_name, cache_dir, type_vocab_size, state_dict=None)

        model = cls(cross_config, decoder_config, clip_state_dict, *inputs, **kwargs)

        ## ===> Initialization trick [HARD CODE]
        if model.linear_patch == "3d":
            # Inflate the 2D patch-embedding weights into the 3D conv by padding
            # zeros along the temporal dimension when no conv2 weight is present.
            contain_conv2 = False
            for key in state_dict.keys():
                if key.find("visual.conv2.weight") > -1:
                    contain_conv2 = True
                    break
            if contain_conv2 is False and hasattr(model.clip.visual, "conv2"):
                cp_weight = state_dict["clip.visual.conv1.weight"].clone()
                kernel_size = model.clip.visual.conv2.weight.size(2)
                conv2_size = model.clip.visual.conv2.weight.size()
                conv2_size = list(conv2_size)

                left_conv2_size = conv2_size.copy()
                right_conv2_size = conv2_size.copy()
                left_conv2_size[2] = (kernel_size - 1) // 2
                right_conv2_size[2] = kernel_size - 1 - left_conv2_size[2]

                left_zeros, right_zeros = None, None
                if left_conv2_size[2] > 0:
                    left_zeros = torch.zeros(*tuple(left_conv2_size), dtype=cp_weight.dtype, device=cp_weight.device)
                if right_conv2_size[2] > 0:
                    right_zeros = torch.zeros(*tuple(right_conv2_size), dtype=cp_weight.dtype, device=cp_weight.device)

                cat_list = []
                if left_zeros is not None: cat_list.append(left_zeros)
                cat_list.append(cp_weight.unsqueeze(2))
                if right_zeros is not None: cat_list.append(right_zeros)
                cp_weight = torch.cat(cat_list, dim=2)

                state_dict["clip.visual.conv2.weight"] = cp_weight
        ## <=== End of initialization trick

        if state_dict is not None:
            model = cls.init_preweight(model, state_dict)

        return model
class CLIP4IDC(CLIP4IDCPreTrainedModel):
    def __init__(self, cross_config, decoder_config, clip_state_dict):
        super(CLIP4IDC, self).__init__(cross_config, decoder_config)
        self.ignore_video_index = -1

        # assert self.task_config.max_words <= cross_config.max_position_embeddings

        # CLIP Encoders: From OpenAI: CLIP [https://github.com/openai/CLIP] ===>
        vit = "visual.proj" in clip_state_dict
        assert vit
        if vit:
            vision_width = clip_state_dict["visual.conv1.weight"].shape[0]
            vision_layers = len(
                [k for k in clip_state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
            vision_patch_size = clip_state_dict["visual.conv1.weight"].shape[-1]
            grid_size = round((clip_state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
            image_resolution = vision_patch_size * grid_size
        else:
            counts: list = [len(set(k.split(".")[2] for k in clip_state_dict if k.startswith(f"visual.layer{b}"))) for b in
                            [1, 2, 3, 4]]
            vision_layers = tuple(counts)
            vision_width = clip_state_dict["visual.layer1.0.conv1.weight"].shape[0]
            output_width = round((clip_state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
            vision_patch_size = None
            assert output_width ** 2 + 1 == clip_state_dict["visual.attnpool.positional_embedding"].shape[0]
            image_resolution = output_width * 32

        embed_dim = clip_state_dict["text_projection"].shape[1]
        context_length = clip_state_dict["positional_embedding"].shape[0]
        vocab_size = clip_state_dict["token_embedding.weight"].shape[0]
        transformer_width = clip_state_dict["ln_final.weight"].shape[0]
        transformer_heads = transformer_width // 64
        transformer_layers = len(set(k.split(".")[2] for k in clip_state_dict if k.startswith("transformer.resblocks")))

        self.linear_patch = '2d'

        # use .float() to avoid overflow/underflow from fp16 weight. https://github.com/openai/CLIP/issues/40
        cut_top_layer = 0
        self.clip = CLIP(
            embed_dim,
            image_resolution, vision_layers - cut_top_layer, vision_width, vision_patch_size,
            context_length, vocab_size, transformer_width, transformer_heads, transformer_layers - cut_top_layer,
            linear_patch=self.linear_patch, intra_layers=9
        ).float()

        bert_word_embeddings_weight = self.clip.token_embedding.weight
        bert_position_embeddings_weight = self.clip.positional_embedding

        for key in ["input_resolution", "context_length", "vocab_size"]:
            if key in clip_state_dict:
                del clip_state_dict[key]

        convert_weights(self.clip)
        # <=== End of CLIP Encoders

        self.decoder = DecoderModel(decoder_config, bert_word_embeddings_weight, bert_position_embeddings_weight)

        self.apply(self.init_weights)
    def get_visual_output(self, video, visual_mask, left_gt_map, right_gt_map, shaped=False, video_frame=-1):
        bs_pair = visual_mask.size(0)
        visual_hidden, visual_output, left_map, right_map = self.clip.encode_image(video, left_gt_map, right_gt_map, video_frame=video_frame, return_hidden=True)
        visual_hidden = visual_hidden.float()
        visual_output = visual_output.float()
        visual_hidden = visual_hidden.view(bs_pair, -1, visual_hidden.size(-1))
        left_map = left_map.float()
        right_map = right_map.float()
        return visual_hidden, visual_output, left_map, right_map
    def get_sequence_visual_output(self, video, visual_mask, left_gt_map, right_gt_map, shaped=False, video_frame=-1):
        if shaped is False:
            visual_mask = visual_mask.view(-1, visual_mask.shape[-1])
            video = torch.as_tensor(video).float()
            b, pair, channel, h, w = video.shape
            video = video.view(b * pair, channel, h, w)
            video_frame = pair

        _, visual_hidden, left_map, right_map = self.get_visual_output(video, visual_mask, left_gt_map, right_gt_map, shaped=True, video_frame=video_frame)
        return visual_hidden, left_map, right_map
    def _get_decoder_score(self, visual_output, visual_mask, input_caption_ids, decoder_mask):
        res_tuples = ()
        decoder_scores = self.decoder(input_caption_ids, encoder_outs=visual_output, answer_mask=decoder_mask, encoder_mask=visual_mask)
        return decoder_scores, res_tuples

    def decoder_caption(self, visual_output, visual_mask, input_caption_ids, decoder_mask, get_logits=False):
        decoder_scores, _ = self._get_decoder_score(visual_output, visual_mask,
                                                    input_caption_ids, decoder_mask)
        if get_logits:
            return decoder_scores
        _, decoder_scores_result = torch.max(decoder_scores, -1)
        return decoder_scores_result
def init_model(model_path, device):
    model_state_dict = torch.load(model_path, map_location='cpu')

    # Prepare model
    cache_dir = ""
    model = CLIP4IDC.from_pretrained("cross-base", "decoder-base", cache_dir=cache_dir, state_dict=model_state_dict)
    model.to(device)

    return model
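

if __name__ == "__main__":
    # Hypothetical usage sketch, not part of the original module: shows how
    # init_model and the encoder/decoder helpers above could be wired together.
    # The checkpoint filename, the 224px ViT-B/16 input size, the 196-patch
    # ground-truth map shape, and the dummy caption length are assumptions.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = init_model("pytorch_model.bin", device)  # assumed checkpoint path
    model.eval()

    # A "before"/"after" image pair packed as (batch, pair, channel, height, width).
    video = torch.randn(1, 2, 3, 224, 224, device=device)
    visual_mask = torch.ones(1, 2, dtype=torch.long, device=device)
    # Dummy change maps; 14 x 14 = 196 patches assumed for a 224px ViT-B/16 grid.
    left_gt_map = torch.zeros(1, 196, device=device)
    right_gt_map = torch.zeros(1, 196, device=device)

    with torch.no_grad():
        visual_output, left_map, right_map = model.get_sequence_visual_output(
            video, visual_mask, left_gt_map, right_gt_map)

        # Single decoding call over dummy caption ids; a real captioning loop
        # would feed the argmax token at each step back into the decoder.
        input_caption_ids = torch.zeros(1, 32, dtype=torch.long, device=device)
        decoder_mask = torch.ones(1, 32, dtype=torch.long, device=device)
        token_ids = model.decoder_caption(visual_output, visual_mask,
                                          input_caption_ids, decoder_mask)
        print(token_ids.shape)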