File size: 5,284 Bytes
32b542e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f2254819-deaf-48ba-848c-471f51ee1221",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4fe03089-ec4d-4bda-9b02-46cb320e516a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the source checkpoint file that the next cell loads.\n",
    "# The literal is split at directory boundaries only; adjacent string literals are\n",
    "# concatenated by Python, so the resulting path is unchanged.\n",
    "origin_checkpoint_path = (\n",
    "    '/mnt/cache/zhujinguo/codes/UniPerceiver/work_dirs/deepspeed_moe/BERT_L12_H768_experiments/'\n",
    "    '16task_90k_bertbase_lr1e-3_wd0.2_gc0.1_prenorm_warm10k_layerscale1e-3_uniformdp0.1_maeinit_fixedpos_torchfp16_unifieddataset_changeweight_stage2_224size/'\n",
    "    'bertbase_womoe_pretrain2/89999/mp_rank_00_model_states.pt'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "edc282cd-8345-4321-b0a0-3e21d64bfa35",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['module', 'buffer_names', 'optimizer', 'lr_scheduler', 'sparse_tensor_module_names', 'skipped_steps', 'global_steps', 'global_samples', 'dp_world_size', 'mp_world_size', 'ds_config', 'ds_version', 'iteration'])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# SECURITY NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "# map_location='cpu' places every loaded tensor on the CPU.\n",
    "origin_checkpoint = torch.load(origin_checkpoint_path, map_location='cpu')\n",
    "\n",
    "# Show the top-level sections of the checkpoint; later cells read the 'module' entry.\n",
    "# To list every entry name instead, run: list(origin_checkpoint['module'].keys())\n",
    "origin_checkpoint.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "79d9f479-3144-4791-82ba-71fec264aa29",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "201"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of entries in the original checkpoint's 'module' mapping.\n",
    "len(origin_checkpoint['module'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3452947d-4593-4431-a772-3a8ad4882c03",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['model', 'trainer', 'amp_scaler', 'scheduler', 'iteration'])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reference only (kept commented out): inspects another checkpoint, presumably one in the target format.\n",
    "# The stored output of this cell comes from an earlier run, made while these lines were still active.\n",
    "# new_checkpoint_path = 'new_exp/model_Epoch_00160_Iter_0000159.pth'\n",
    "# new_checkpoint = torch.load(new_checkpoint_path, 'cpu')\n",
    "# new_checkpoint.keys()\n",
    "# list(new_checkpoint['model'].keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ffdcf5c5-ffd4-4379-89d7-37ce05c4c0f2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "41"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reference only (kept commented out); the stored output comes from an earlier run of this line.\n",
    "# len(list(new_checkpoint['model'].keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fec7a303-a30c-4e92-9452-b534a52d67e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rename rules for checkpoint keys: each pair is (substring to find, replacement).\n",
    "# The conversion cell applies the rules in the order listed here.\n",
    "_rename_rules = [\n",
    "    # module prefix\n",
    "    ('encoder.', 'fused_encoder.'),\n",
    "    # attention entries\n",
    "    ('attention.self.qkv_proj.weight', 'self_attn.in_proj_weight'),\n",
    "    ('attention.self.qkv_proj.bias', 'self_attn.in_proj_bias'),\n",
    "    ('attention.output.dense', 'self_attn.out_proj'),\n",
    "    ('attention_output.residual_scale', 'gamma_1'),\n",
    "    # ffn entries\n",
    "    ('ffn.dense.', 'linear1.'),\n",
    "    ('ffn.dense2.', 'linear2.'),\n",
    "    ('ffn_output.residual_scale', 'gamma_2'),\n",
    "    # layer-norm entries\n",
    "    ('LayerNormModules.0.', 'norm1.'),\n",
    "    ('LayerNormModules.1.', 'norm2.'),\n",
    "    # predictor prefix\n",
    "    ('predictor.', 'loss_prepare.'),\n",
    "]\n",
    "mapping_dict = dict(_rename_rules)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "897ff2f0-1232-4d25-9c13-7ea9568da362",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build new_checkpoint from origin_checkpoint['module']: skip 'visual_embed' entries,\n",
    "# rename keys with mapping_dict, and store every tensor as float32.\n",
    "new_checkpoint = {}\n",
    "\n",
    "module_checkpoint = origin_checkpoint['module']\n",
    "\n",
    "for key, tensor in module_checkpoint.items():\n",
    "    # Entries whose name starts with 'visual_embed' are not carried over.\n",
    "    if key.startswith('visual_embed'):\n",
    "        continue\n",
    "\n",
    "    if key.endswith('residual_scale'):\n",
    "        # Remove dims 1 and 0 when they have size 1.\n",
    "        # The out-of-place squeeze leaves the tensors inside origin_checkpoint untouched,\n",
    "        # so this cell gives the same result every time it is run. The former in-place\n",
    "        # squeeze_ reshaped the loaded tensors and could raise a dimension error on re-run.\n",
    "        tensor = tensor.squeeze(1).squeeze(0)\n",
    "\n",
    "    # Apply every rename rule in mapping_dict order; a rule that does not match is a no-op.\n",
    "    new_key = key\n",
    "    for origin_str, target_str in mapping_dict.items():\n",
    "        new_key = new_key.replace(origin_str, target_str)\n",
    "\n",
    "    new_checkpoint[new_key] = tensor.float()\n",
    "\n",
    "# Merge the type embedding into video_embed: add row 0 of the type-embedding weight to the bias.\n",
    "# NOTE(review): 'video_embed.embeddings_type.weight' itself stays in new_checkpoint after the\n",
    "# merge - confirm that the target model expects this key.\n",
    "new_checkpoint['video_embed.embeddings.bias'] = (\n",
    "    new_checkpoint['video_embed.embeddings.bias']\n",
    "    + new_checkpoint['video_embed.embeddings_type.weight'][0]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3c26719f-7451-4c0a-85c3-640c820dfe98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the converted weights to disk, wrapped in a dict under the 'model' key.\n",
    "torch.save(\n",
    "    obj={'model': new_checkpoint},\n",
    "    f='/mnt/lustre/zhujinguo/codes/Uni-Perceiver/work_dirs/pretrained_models/uni-perceiver-base-L12-H768-224size-pretrained.pth',\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}