from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
from torch import nn

from .until_module import PreTrainedModel
from .module_cross import CrossModel, CrossConfig
from .module_decoder import DecoderModel, DecoderConfig

from utils.module_clip import CLIP, convert_weights
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence


def update_attr(target_name, target_config, target_attr_name, source_config, source_attr_name, default_value=None):
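    """Copy source_config.<source_attr_name> onto target_config.<target_attr_name>
    when it exists and (if a default is given) differs from default_value."""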
    if hasattr(source_config, source_attr_name):
        if default_value is None or getattr(source_config, source_attr_name) != default_value:
            setattr(target_config, target_attr_name, getattr(source_config, source_attr_name))
    return target_config

class CLIP4IDCPreTrainedModel(PreTrainedModel, nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, cross_config, decoder_config, *inputs, **kwargs):
        super(CLIP4IDCPreTrainedModel, self).__init__(cross_config, decoder_config)
        self.cross_config = cross_config
        self.decoder_config = decoder_config
        self.clip = None
        self.cross = None

    @classmethod
    def from_pretrained(cls, cross_model_name, decoder_model_name, state_dict=None, cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
        if state_dict is None:
            state_dict = {}
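        # Load the pretrained OpenAI CLIP (ViT-B/16) weights and merge them into the
        # incoming state_dict under the "clip." prefix so they are loaded into the model below.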
        pretrained_clip_name = "ViT-B/16"
        clip_state_dict = CLIP.get_config(pretrained_clip_name=pretrained_clip_name)
        for key, val in clip_state_dict.items():
            new_key = "clip." + key
            if new_key not in state_dict:
                state_dict[new_key] = val.clone()

        cross_config, _ = CrossConfig.get_config(cross_model_name, cache_dir, type_vocab_size, state_dict=None)
        decoder_config, _ = DecoderConfig.get_config(decoder_model_name, cache_dir, type_vocab_size, state_dict=None)

        model = cls(cross_config, decoder_config, clip_state_dict, *inputs, **kwargs)

        ## ===> Initialization trick [HARD CODE]
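        # When the visual backbone uses the "3d" linear patch, conv2 is a temporal (3D)
        # convolution with no pretrained weights. It is initialized by placing CLIP's
        # 2D conv1 kernel in one temporal slice and zero-filling the remaining slices.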
        if model.linear_patch == "3d":
            contain_conv2 = False
            for key in state_dict.keys():
                if key.find("visual.conv2.weight") > -1:
                    contain_conv2 = True
                    break
            if contain_conv2 is False and hasattr(model.clip.visual, "conv2"):
                cp_weight = state_dict["clip.visual.conv1.weight"].clone()
                kernel_size = model.clip.visual.conv2.weight.size(2)
                conv2_size = model.clip.visual.conv2.weight.size()
                conv2_size = list(conv2_size)

                left_conv2_size = conv2_size.copy()
                right_conv2_size = conv2_size.copy()
                left_conv2_size[2] = (kernel_size - 1) // 2
                right_conv2_size[2] = kernel_size - 1 - left_conv2_size[2]

                left_zeros, right_zeros = None, None
                if left_conv2_size[2] > 0:
                    left_zeros = torch.zeros(*tuple(left_conv2_size), dtype=cp_weight.dtype, device=cp_weight.device)
                if right_conv2_size[2] > 0:
                    right_zeros = torch.zeros(*tuple(right_conv2_size), dtype=cp_weight.dtype, device=cp_weight.device)

                cat_list = []
                if left_zeros is not None: cat_list.append(left_zeros)
                cat_list.append(cp_weight.unsqueeze(2))
                if right_zeros is not None: cat_list.append(right_zeros)
                cp_weight = torch.cat(cat_list, dim=2)

                state_dict["clip.visual.conv2.weight"] = cp_weight

        ## <=== End of initialization trick

        if state_dict is not None:
            model = cls.init_preweight(model, state_dict)

        return model



class CLIP4IDC(CLIP4IDCPreTrainedModel):
    def __init__(self, cross_config, decoder_config, clip_state_dict):
        super(CLIP4IDC, self).__init__(cross_config, decoder_config)
        self.ignore_video_index = -1

        # assert self.task_config.max_words <= cross_config.max_position_embeddings

        # CLIP Encoders: From OpenAI: CLIP [https://github.com/openai/CLIP] ===>
        vit = "visual.proj" in clip_state_dict
        assert vit
        if vit:
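            # ViT hyper-parameters are recovered from tensor shapes in the CLIP state dict.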
            vision_width = clip_state_dict["visual.conv1.weight"].shape[0]
            vision_layers = len(
                [k for k in clip_state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
            vision_patch_size = clip_state_dict["visual.conv1.weight"].shape[-1]
            grid_size = round((clip_state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
            image_resolution = vision_patch_size * grid_size
        else:
            counts: list = [len(set(k.split(".")[2] for k in clip_state_dict if k.startswith(f"visual.layer{b}"))) for b in
                            [1, 2, 3, 4]]
            vision_layers = tuple(counts)
            vision_width = clip_state_dict["visual.layer1.0.conv1.weight"].shape[0]
            output_width = round((clip_state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
            vision_patch_size = None
            assert output_width ** 2 + 1 == clip_state_dict["visual.attnpool.positional_embedding"].shape[0]
            image_resolution = output_width * 32

        embed_dim = clip_state_dict["text_projection"].shape[1]
        context_length = clip_state_dict["positional_embedding"].shape[0]
        vocab_size = clip_state_dict["token_embedding.weight"].shape[0]
        transformer_width = clip_state_dict["ln_final.weight"].shape[0]
        transformer_heads = transformer_width // 64
        transformer_layers = len(set(k.split(".")[2] for k in clip_state_dict if k.startswith("transformer.resblocks")))

        self.linear_patch = '2d'
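        # Hard-coded to CLIP's standard 2D patch embedding; the conv2 inflation in
        # from_pretrained only applies when this is set to "3d".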

        # use .float() to avoid overflow/underflow from fp16 weight. https://github.com/openai/CLIP/issues/40
        cut_top_layer = 0
        self.clip = CLIP(
            embed_dim,
            image_resolution, vision_layers-cut_top_layer, vision_width, vision_patch_size,
            context_length, vocab_size, transformer_width, transformer_heads, transformer_layers-cut_top_layer,
            linear_patch=self.linear_patch, intra_layers=9
        ).float()

        bert_word_embeddings_weight = self.clip.token_embedding.weight
        bert_position_embeddings_weight = self.clip.positional_embedding
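        # Keep references to CLIP's token/positional embeddings; they are passed to
        # the caption decoder below.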

        for key in ["input_resolution", "context_length", "vocab_size"]:
            if key in clip_state_dict:
                del clip_state_dict[key]

        convert_weights(self.clip)
        # <=== End of CLIP Encoders

        self.decoder = DecoderModel(decoder_config, bert_word_embeddings_weight, bert_position_embeddings_weight)

        self.apply(self.init_weights)

    def get_visual_output(self, video, visual_mask, left_gt_map, right_gt_map, shaped=False, video_frame=-1):

        bs_pair = visual_mask.size(0)
        visual_hidden, visual_output, left_map, right_map = self.clip.encode_image(video, left_gt_map, right_gt_map, video_frame=video_frame, return_hidden=True)
        visual_hidden = visual_hidden.float()
        visual_output = visual_output.float()
        visual_hidden = visual_hidden.view(bs_pair, -1, visual_hidden.size(-1))

        left_map = left_map.float()
        right_map = right_map.float()

        return visual_hidden, visual_output, left_map, right_map

    def get_sequence_visual_output(self, video, visual_mask, left_gt_map, right_gt_map, shaped=False, video_frame=-1):
        if shaped is False:
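            # Flatten the (batch, pair) dimensions so CLIP encodes each image in the
            # pair independently; video_frame records the pair size.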
            visual_mask = visual_mask.view(-1, visual_mask.shape[-1])
            video = torch.as_tensor(video).float()
            b, pair, channel, h, w = video.shape
            video = video.view(b * pair, channel, h, w)
            video_frame = pair

        _, visual_hidden, left_map, right_map = self.get_visual_output(video, visual_mask, left_gt_map, right_gt_map, shaped=True, video_frame=video_frame)

        return visual_hidden, left_map, right_map
    
    def _get_decoder_score(self, visual_output, visual_mask, input_caption_ids, decoder_mask):
        res_tuples = ()
        decoder_scores = self.decoder(input_caption_ids, encoder_outs=visual_output, answer_mask=decoder_mask, encoder_mask=visual_mask)

        return decoder_scores, res_tuples

    def decoder_caption(self, visual_output, visual_mask, input_caption_ids, decoder_mask, get_logits=False):

        decoder_scores, _ = self._get_decoder_score(visual_output, visual_mask,
                                                    input_caption_ids, decoder_mask)

        if get_logits:
            return decoder_scores

        _, decoder_scores_result = torch.max(decoder_scores, -1)

        return decoder_scores_result


def init_model(model_path, device):

    model_state_dict = torch.load(model_path, map_location='cpu')
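    # The checkpoint is loaded on CPU first; the model is moved to the target device below.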

    # Prepare model
    cache_dir = ""
    model = CLIP4IDC.from_pretrained("cross-base", "decoder-base", cache_dir=cache_dir, state_dict=model_state_dict)

    model.to(device)

    return model
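

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# "pytorch_model.bin" is a placeholder checkpoint path, and the commented call
# below only restates the shapes implied by get_sequence_visual_output; the
# ground-truth maps and caption ids come from the project's data pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = init_model("pytorch_model.bin", device)  # placeholder path
    model.eval()

    # Typical inference flow:
    #   visual_output, left_map, right_map = model.get_sequence_visual_output(
    #       video,        # (batch, pair, channel, h, w) image pair
    #       visual_mask,  # (batch, pair) attention mask
    #       left_gt_map, right_gt_map)  # ground-truth maps from the data loader
    #   token_ids = model.decoder_caption(
    #       visual_output, visual_mask, input_caption_ids, decoder_mask)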