from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
from torch import nn

from .until_module import PreTrainedModel
from .module_cross import CrossModel, CrossConfig
from .module_decoder import DecoderModel, DecoderConfig
from utils.module_clip import CLIP, convert_weights

from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence


def update_attr(target_name, target_config, target_attr_name, source_config, source_attr_name, default_value=None):
    if hasattr(source_config, source_attr_name):
        if default_value is None or getattr(source_config, source_attr_name) != default_value:
            setattr(target_config, target_attr_name, getattr(source_config, source_attr_name))
    return target_config


class CLIP4IDCPreTrainedModel(PreTrainedModel, nn.Module):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, cross_config, decoder_config, *inputs, **kwargs):
        super(CLIP4IDCPreTrainedModel, self).__init__(cross_config, decoder_config)
        self.cross_config = cross_config
        self.decoder_config = decoder_config
        self.clip = None
        self.cross = None

    @classmethod
    def from_pretrained(cls, cross_model_name, decoder_model_name, state_dict=None, cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
        if state_dict is None: state_dict = {}

        # Load the pretrained CLIP weights and prefix their keys so they map onto self.clip.
        pretrained_clip_name = "ViT-B/16"
        clip_state_dict = CLIP.get_config(pretrained_clip_name=pretrained_clip_name)
        for key, val in clip_state_dict.items():
            new_key = "clip." + key
            if new_key not in state_dict:
                state_dict[new_key] = val.clone()

        cross_config, _ = CrossConfig.get_config(cross_model_name, cache_dir, type_vocab_size, state_dict=None)
        decoder_config, _ = DecoderConfig.get_config(decoder_model_name, cache_dir, type_vocab_size, state_dict=None)

        model = cls(cross_config, decoder_config, clip_state_dict, *inputs, **kwargs)

        ## ===> Initialization trick [HARD CODE]
        if model.linear_patch == "3d":
            # Inflate CLIP's 2D patch-embedding kernel (conv1) into the 3D kernel expected by
            # conv2: place the 2D weights in the middle temporal slice and zero-pad the rest.
            contain_conv2 = False
            for key in state_dict.keys():
                if key.find("visual.conv2.weight") > -1:
                    contain_conv2 = True
                    break
            if contain_conv2 is False and hasattr(model.clip.visual, "conv2"):
                cp_weight = state_dict["clip.visual.conv1.weight"].clone()
                kernel_size = model.clip.visual.conv2.weight.size(2)
                conv2_size = model.clip.visual.conv2.weight.size()
                conv2_size = list(conv2_size)

                left_conv2_size = conv2_size.copy()
                right_conv2_size = conv2_size.copy()
                left_conv2_size[2] = (kernel_size - 1) // 2
                right_conv2_size[2] = kernel_size - 1 - left_conv2_size[2]

                left_zeros, right_zeros = None, None
                if left_conv2_size[2] > 0:
                    left_zeros = torch.zeros(*tuple(left_conv2_size), dtype=cp_weight.dtype, device=cp_weight.device)
                if right_conv2_size[2] > 0:
                    right_zeros = torch.zeros(*tuple(right_conv2_size), dtype=cp_weight.dtype, device=cp_weight.device)

                cat_list = []
                if left_zeros is not None: cat_list.append(left_zeros)
                cat_list.append(cp_weight.unsqueeze(2))
                if right_zeros is not None: cat_list.append(right_zeros)
                cp_weight = torch.cat(cat_list, dim=2)

                state_dict["clip.visual.conv2.weight"] = cp_weight
        ## <=== End of initialization trick

        if state_dict is not None:
            model = cls.init_preweight(model, state_dict)

        return model
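

# The sketch below is illustrative only and is not used by the model: it reproduces, on a toy
# tensor, the 2D-to-3D kernel inflation performed by the initialization trick above. The name
# `inflate_2d_kernel` is hypothetical.
def inflate_2d_kernel(conv1_weight, temporal_kernel_size):
    # conv1_weight: (out_channels, in_channels, k, k) 2D patch-embedding kernel.
    left = (temporal_kernel_size - 1) // 2
    right = temporal_kernel_size - 1 - left
    slices = []
    if left > 0:
        slices.append(torch.zeros(*conv1_weight.shape[:2], left, *conv1_weight.shape[2:],
                                  dtype=conv1_weight.dtype, device=conv1_weight.device))
    slices.append(conv1_weight.unsqueeze(2))  # add a temporal dimension of size 1
    if right > 0:
        slices.append(torch.zeros(*conv1_weight.shape[:2], right, *conv1_weight.shape[2:],
                                  dtype=conv1_weight.dtype, device=conv1_weight.device))
    # Result: (out_channels, in_channels, temporal_kernel_size, k, k).
    return torch.cat(slices, dim=2)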


class CLIP4IDC(CLIP4IDCPreTrainedModel):
    def __init__(self, cross_config, decoder_config, clip_state_dict):
        super(CLIP4IDC, self).__init__(cross_config, decoder_config)
        self.ignore_video_index = -1

        # assert self.task_config.max_words <= cross_config.max_position_embeddings

        # CLIP Encoders: From OpenAI: CLIP [https://github.com/openai/CLIP] ===>
        vit = "visual.proj" in clip_state_dict
        assert vit
        if vit:
            vision_width = clip_state_dict["visual.conv1.weight"].shape[0]
            vision_layers = len(
                [k for k in clip_state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
            vision_patch_size = clip_state_dict["visual.conv1.weight"].shape[-1]
            grid_size = round((clip_state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
            image_resolution = vision_patch_size * grid_size
        else:
            counts: list = [len(set(k.split(".")[2] for k in clip_state_dict if k.startswith(f"visual.layer{b}"))) for b in
                            [1, 2, 3, 4]]
            vision_layers = tuple(counts)
            vision_width = clip_state_dict["visual.layer1.0.conv1.weight"].shape[0]
            output_width = round((clip_state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
            vision_patch_size = None
            assert output_width ** 2 + 1 == clip_state_dict["visual.attnpool.positional_embedding"].shape[0]
            image_resolution = output_width * 32

        embed_dim = clip_state_dict["text_projection"].shape[1]
        context_length = clip_state_dict["positional_embedding"].shape[0]
        vocab_size = clip_state_dict["token_embedding.weight"].shape[0]
        transformer_width = clip_state_dict["ln_final.weight"].shape[0]
        transformer_heads = transformer_width // 64
        transformer_layers = len(set(k.split(".")[2] for k in clip_state_dict if k.startswith("transformer.resblocks")))

        self.linear_patch = '2d'

        # use .float() to avoid overflow/underflow from fp16 weight. https://github.com/openai/CLIP/issues/40
        cut_top_layer = 0
        self.clip = CLIP(
            embed_dim,
            image_resolution, vision_layers - cut_top_layer, vision_width, vision_patch_size,
            context_length, vocab_size, transformer_width, transformer_heads, transformer_layers - cut_top_layer,
            linear_patch=self.linear_patch, intra_layers=9
        ).float()

        bert_word_embeddings_weight = self.clip.token_embedding.weight
        bert_position_embeddings_weight = self.clip.positional_embedding

        for key in ["input_resolution", "context_length", "vocab_size"]:
            if key in clip_state_dict:
                del clip_state_dict[key]
        convert_weights(self.clip)
        # <=== End of CLIP Encoders

        self.decoder = DecoderModel(decoder_config, bert_word_embeddings_weight, bert_position_embeddings_weight)

        self.apply(self.init_weights)

    def get_visual_output(self, video, visual_mask, left_gt_map, right_gt_map, shaped=False, video_frame=-1):
        bs_pair = visual_mask.size(0)
        visual_hidden, visual_output, left_map, right_map = self.clip.encode_image(video, left_gt_map, right_gt_map, video_frame=video_frame, return_hidden=True)
        visual_hidden = visual_hidden.float()
        visual_output = visual_output.float()
        # Restore the (batch, frames, dim) layout that was flattened for the CLIP encoder.
        visual_hidden = visual_hidden.view(bs_pair, -1, visual_hidden.size(-1))

        left_map = left_map.float()
        right_map = right_map.float()

        return visual_hidden, visual_output, left_map, right_map

    def get_sequence_visual_output(self, video, visual_mask, left_gt_map, right_gt_map, shaped=False, video_frame=-1):
        if shaped is False:
            visual_mask = visual_mask.view(-1, visual_mask.shape[-1])
            video = torch.as_tensor(video).float()
            # Flatten the ("before", "after") image pair into the batch dimension for CLIP.
            b, pair, channel, h, w = video.shape
            video = video.view(b * pair, channel, h, w)
            video_frame = pair

        _, visual_hidden, left_map, right_map = self.get_visual_output(video, visual_mask, left_gt_map, right_gt_map, shaped=True, video_frame=video_frame)

        return visual_hidden, left_map, right_map

    def _get_decoder_score(self, visual_output, visual_mask, input_caption_ids, decoder_mask):
        res_tuples = ()
        decoder_scores = self.decoder(input_caption_ids, encoder_outs=visual_output, answer_mask=decoder_mask, encoder_mask=visual_mask)
        return decoder_scores, res_tuples

    def decoder_caption(self, visual_output, visual_mask, input_caption_ids, decoder_mask, get_logits=False):
        decoder_scores, _ = self._get_decoder_score(visual_output, visual_mask,
                                                    input_caption_ids, decoder_mask)
        if get_logits:
            return decoder_scores

        # Greedy choice: return the highest-scoring token id at each position.
        _, decoder_scores_result = torch.max(decoder_scores, -1)

        return decoder_scores_result
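

# A minimal greedy-decoding sketch built on `decoder_caption`; it is not part of the original
# model, and the BOS/EOS ids (CLIP's <|startoftext|>/<|endoftext|>), the maximum length, and the
# assumption that the decoder accepts growing prefixes are all unverified here.
def greedy_decode(model, visual_output, visual_mask, bos_id=49406, eos_id=49407, max_len=32):
    device = visual_output.device
    batch_size = visual_output.size(0)
    # Start every caption with the BOS token and grow it one token per step.
    caption_ids = torch.full((batch_size, 1), bos_id, dtype=torch.long, device=device)
    for _ in range(max_len - 1):
        decoder_mask = torch.ones_like(caption_ids)
        # decoder_caption returns the argmax token id at every position of the current prefix.
        next_ids = model.decoder_caption(visual_output, visual_mask, caption_ids, decoder_mask)
        next_token = next_ids[:, -1:]  # keep only the prediction for the last position
        caption_ids = torch.cat([caption_ids, next_token], dim=1)
        if (next_token == eos_id).all():
            break
    return caption_ids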


def init_model(model_path, device):
    model_state_dict = torch.load(model_path, map_location='cpu')

    # Prepare model
    cache_dir = ""
    model = CLIP4IDC.from_pretrained("cross-base", "decoder-base", cache_dir=cache_dir, state_dict=model_state_dict)

    model.to(device)

    return model
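

# A hedged usage sketch, not executed on import: the checkpoint path and the input shapes below
# are assumptions (224x224 follows from ViT-B/16) and must match the data pipeline used in training.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = init_model("pytorch_model.bin", device)  # placeholder checkpoint path
    model.eval()

    # Inputs for get_sequence_visual_output: one "before"/"after" image pair per sample plus a
    # frame mask. The ground-truth attention maps (left_gt_map / right_gt_map) come from the
    # dataset and are omitted here; captions then follow via decoder_caption (or greedy_decode above).
    video = torch.randn(1, 2, 3, 224, 224, device=device)
    visual_mask = torch.ones(1, 2, dtype=torch.long, device=device)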