multimodalart (HF Staff) committed
Commit d0cbcd5 · verified · 1 parent: c8ad832

Upload 69 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete list.
Files changed (50)
  1. blip3o/.DS_Store +0 -0
  2. blip3o/__init__.py +1 -0
  3. blip3o/constants.py +26 -0
  4. blip3o/conversation.py +479 -0
  5. blip3o/mm_utils.py +247 -0
  6. blip3o/model/__init__.py +3 -0
  7. blip3o/model/apply_delta.py +48 -0
  8. blip3o/model/blip3o_arch.py +415 -0
  9. blip3o/model/builder.py +54 -0
  10. blip3o/model/consolidate.py +25 -0
  11. blip3o/model/language_model/blip3o_llama.py +413 -0
  12. blip3o/model/language_model/blip3o_qwen.py +420 -0
  13. blip3o/model/lumina_nextdit2d.py +365 -0
  14. blip3o/model/make_delta.py +48 -0
  15. blip3o/model/multimodal_encoder/builder.py +63 -0
  16. blip3o/model/multimodal_encoder/clip_encoder.py +172 -0
  17. blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py +9 -0
  18. blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  19. blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py +2 -0
  20. blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/eva_vit_model.py +571 -0
  21. blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/factory.py +528 -0
  22. blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py +57 -0
  23. blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_model.py +240 -0
  24. blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py +123 -0
  25. blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/model.py +429 -0
  26. blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py +179 -0
  27. blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py +144 -0
  28. blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/pretrained.py +314 -0
  29. blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py +131 -0
  30. blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py +114 -0
  31. blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/tokenizer.py +205 -0
  32. blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py +104 -0
  33. blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/transformer.py +683 -0
  34. blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/utils.py +321 -0
  35. blip3o/model/multimodal_encoder/dev_eva_clip/eva_vit.py +140 -0
  36. blip3o/model/multimodal_encoder/eva_clip/eva_clip_encoder.py +75 -0
  37. blip3o/model/multimodal_encoder/eva_clip/eva_clip_processors.py +74 -0
  38. blip3o/model/multimodal_encoder/eva_clip/eva_vit.py +762 -0
  39. blip3o/model/multimodal_encoder/eva_clip/factory.py +59 -0
  40. blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-18B.json +27 -0
  41. blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B-plus.json +27 -0
  42. blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B.json +27 -0
  43. blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-B-16.json +19 -0
  44. blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json +24 -0
  45. blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14.json +24 -0
  46. blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-B-16.json +29 -0
  47. blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14-336.json +29 -0
  48. blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14.json +29 -0
  49. blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json +28 -0
  50. blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14.json +25 -0
blip3o/.DS_Store ADDED
Binary file (6.15 kB).
 
blip3o/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .model import blip3oLlamaForCausalLM
blip3o/constants.py ADDED
@@ -0,0 +1,26 @@
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ # IMAGE_TOKEN_INDEX = -200
9
+
10
+ DEFAULT_IMAGE_TOKEN = "<image>"
11
+ DEFAULT_IM_START_TOKEN = "[IMG]"
12
+ DEFAULT_IM_END_TOKEN = "[/IMG]"
13
+
14
+
15
+ # IMAGE_TOKEN_IDX = 32002
16
+ # DEFAULT_IM_START_TOKEN_IDX = 32000
17
+ # DEFAULT_IM_END_TOKEN_IDX = 32001
18
+
19
+ IMAGE_TOKEN_IDX = 151667
20
+ DEFAULT_IM_START_TOKEN_IDX = 128257
21
+ DEFAULT_IM_END_TOKEN_IDX = 128258
22
+ UND_IMAGE_TOKEN_IDX = 151655
23
+ # N_QUERY = 729
24
+
25
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
26
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
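Taken together, DEFAULT_IMAGE_TOKEN is the textual placeholder that appears in prompts, while IMAGE_TOKEN_IDX is the token id that replaces it at tokenization time (see tokenizer_image_token in blip3o/mm_utils.py below). A minimal sketch, assuming a Hugging Face tokenizer for this checkpoint is already available and using a made-up prompt:

from transformers import AutoTokenizer
from blip3o.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_IDX
from blip3o.mm_utils import tokenizer_image_token

tokenizer = AutoTokenizer.from_pretrained("path/to/blip3o-checkpoint")  # placeholder path

# The "<image>" placeholder in the prompt is replaced by IMAGE_TOKEN_IDX (151667)
# in the resulting id sequence; the surrounding text is tokenized normally.
prompt = DEFAULT_IMAGE_TOKEN + "\nDescribe this picture."
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_IDX, return_tensors="pt")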
blip3o/conversation.py ADDED
@@ -0,0 +1,479 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+ import base64
5
+ from io import BytesIO
6
+ from PIL import Image
7
+
8
+
9
+ class SeparatorStyle(Enum):
10
+ """Different separator style."""
11
+ SINGLE = auto()
12
+ TWO = auto()
13
+ MPT = auto()
14
+ PLAIN = auto()
15
+ LLAMA_2 = auto()
16
+ CHATML = auto()
17
+ QWEN = auto()
+ LLAMA_3 = auto()  # referenced by get_prompt below
+ GEMMA = auto()  # referenced by get_prompt below
18
+
19
+
20
+ @dataclasses.dataclass
21
+ class Conversation:
22
+ """A class that keeps all conversation history."""
23
+ system: str
24
+ roles: List[str]
25
+ messages: List[List[str]]
26
+ offset: int
27
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
28
+ sep: str = "###"
29
+ sep2: str = None
30
+ version: str = "Unknown"
+ tokenizer: object = None  # only needed when sep_style is LLAMA_3
31
+
32
+ skip_next: bool = False
33
+
34
+ def get_prompt(self):
35
+ messages = self.messages
36
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
37
+ messages = self.messages.copy()
38
+ init_role, init_msg = messages[0].copy()
39
+ init_msg = init_msg[0]
40
+ if "mmtag" in self.version:
41
+ init_msg = init_msg.replace("<image>", "").strip()
42
+ messages[0] = (init_role, init_msg)
43
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
44
+ messages.insert(1, (self.roles[1], "Received."))
45
+ elif not init_msg.startswith("<image>"):
46
+ init_msg = init_msg.replace("<image>", "").strip()
47
+ messages[0] = (init_role, "<image>\n" + init_msg)
48
+ else:
49
+ messages[0] = (init_role, init_msg)
50
+
51
+ if self.sep_style == SeparatorStyle.SINGLE:
52
+ ret = self.system + self.sep
53
+ for role, message in messages:
54
+ if message:
55
+ if type(message) is tuple:
56
+ message, _, _ = message
57
+ ret += role + ": " + message + self.sep
58
+ else:
59
+ ret += role + ":"
60
+
61
+ elif self.sep_style == SeparatorStyle.TWO:
62
+ seps = [self.sep, self.sep2]
63
+ ret = self.system + seps[0]
64
+ for i, (role, message) in enumerate(messages):
65
+ if message:
66
+ if type(message) is tuple:
67
+ message, _, _ = message
68
+ ret += role + ": " + message + seps[i % 2]
69
+ else:
70
+ ret += role + ":"
71
+
72
+ elif self.sep_style == SeparatorStyle.CHATML:
73
+ ret = "" if self.system == "" else self.system + self.sep + "\n"
74
+ for role, message in messages:
75
+ if message:
76
+ if type(message) is tuple:
77
+ message, images, _ = message
78
+ message = "<image>" * len(images) + message
79
+ ret += role + "\n" + message + self.sep + "\n"
80
+ else:
81
+ ret += role + "\n"
82
+ return ret
83
+
84
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
85
+ if self.tokenizer is None:
86
+ raise ValueError("Llama 3 tokenizer is not available. Make sure you have the necessary permissions.")
87
+ chat_template_messages = [{"role": "system", "content": self.system}]
88
+ for role, message in messages:
89
+ if message:
90
+ if type(message) is tuple:
91
+ message, images = message
92
+ message = "<image>" * len(images) + message
93
+ chat_template_messages.append({"role": role, "content": message})
94
+
95
+ # print(chat_template_messages)
96
+ return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True)
97
+ # ret = "" if self.system == "" else self.system + self.sep + "\n"
98
+ # for role, message in messages:
99
+ # if message:
100
+ # if type(message) is tuple:
101
+ # message, images = message
102
+ # message = "<image>" * len(images) + message
103
+ # ret += role + "\n" + message + self.sep + "\n"
104
+ # else:
105
+ # ret += role + "\n"
106
+ # return ret
107
+
108
+ elif self.sep_style == SeparatorStyle.MPT:
109
+ ret = self.system + self.sep
110
+ for role, message in messages:
111
+ if message:
112
+ if type(message) is tuple:
113
+ message, _, _ = message
114
+ ret += role + message + self.sep
115
+ else:
116
+ ret += role
117
+
118
+ elif self.sep_style == SeparatorStyle.GEMMA:
119
+ ret = ""
120
+ for i, (role, message) in enumerate(messages):
121
+ assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
122
+ if message:
123
+ if type(message) is tuple:
124
+ message, _, _ = message
125
+ ret += role + message + self.sep
126
+ else:
127
+ ret += role
128
+
129
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
130
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
131
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
132
+ ret = ""
133
+
134
+ for i, (role, message) in enumerate(messages):
135
+ if i == 0:
136
+ assert message, "first message should not be none"
137
+ assert role == self.roles[0], "first message should come from user"
138
+ if message:
139
+ if type(message) is tuple:
140
+ message, _, _ = message
141
+ if i == 0:
142
+ message = wrap_sys(self.system) + message
143
+ if i % 2 == 0:
144
+ message = wrap_inst(message)
145
+ ret += self.sep + message
146
+ else:
147
+ ret += " " + message + " " + self.sep2
148
+ else:
149
+ ret += ""
150
+ ret = ret.lstrip(self.sep)
151
+
152
+ elif self.sep_style == SeparatorStyle.PLAIN:
153
+ seps = [self.sep, self.sep2]
154
+ ret = self.system
155
+ for i, (role, message) in enumerate(messages):
156
+ if message:
157
+ if type(message) is tuple:
158
+ message, _, _ = message
159
+ ret += message + seps[i % 2]
160
+ else:
161
+ ret += ""
162
+ else:
163
+ raise ValueError(f"Invalid style: {self.sep_style}")
164
+
165
+ return ret
166
+
167
+ def append_message(self, role, message):
168
+ self.messages.append([role, message])
169
+
170
+ def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
171
+ if image_process_mode == "Pad":
172
+ def expand2square(pil_img, background_color=(122, 116, 104)):
173
+ width, height = pil_img.size
174
+ if width == height:
175
+ return pil_img
176
+ elif width > height:
177
+ result = Image.new(pil_img.mode, (width, width), background_color)
178
+ result.paste(pil_img, (0, (width - height) // 2))
179
+ return result
180
+ else:
181
+ result = Image.new(pil_img.mode, (height, height), background_color)
182
+ result.paste(pil_img, ((height - width) // 2, 0))
183
+ return result
184
+
185
+ image = expand2square(image)
186
+ elif image_process_mode in ["Default", "Crop"]:
187
+ pass
188
+ elif image_process_mode == "Resize":
189
+ image = image.resize((336, 336))
190
+ else:
191
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
192
+ if max(image.size) > max_len:
193
+ max_hw, min_hw = max(image.size), min(image.size)
194
+ aspect_ratio = max_hw / min_hw
195
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
196
+ longest_edge = int(shortest_edge * aspect_ratio)
197
+ W, H = image.size
198
+ if H > W:
199
+ H, W = longest_edge, shortest_edge
200
+ else:
201
+ H, W = shortest_edge, longest_edge
202
+ image = image.resize((W, H))
203
+ if return_pil:
204
+ return image
205
+ else:
206
+ buffered = BytesIO()
207
+ image.save(buffered, format=image_format)
208
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
209
+ return img_b64_str
210
+
211
+ def get_images(self, return_pil=False):
212
+ images = []
213
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
214
+ if i % 2 == 0:
215
+ if type(msg) is tuple:
216
+ msg, image, image_process_mode = msg
217
+ image = self.process_image(image, image_process_mode, return_pil=return_pil)
218
+ images.append(image)
219
+ return images
220
+
221
+ def to_gradio_chatbot(self):
222
+ ret = []
223
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
224
+ if i % 2 == 0:
225
+ if type(msg) is tuple:
226
+ msg, image, image_process_mode = msg
227
+ img_b64_str = self.process_image(
228
+ image, "Default", return_pil=False,
229
+ image_format='JPEG')
230
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
231
+ msg = img_str + msg.replace('<image>', '').strip()
232
+ ret.append([msg, None])
233
+ else:
234
+ ret.append([msg, None])
235
+ else:
236
+ ret[-1][-1] = msg
237
+ return ret
238
+
239
+ def copy(self):
240
+ return Conversation(
241
+ system=self.system,
242
+ roles=self.roles,
243
+ messages=[[x, y] for x, y in self.messages],
244
+ offset=self.offset,
245
+ sep_style=self.sep_style,
246
+ sep=self.sep,
247
+ sep2=self.sep2,
248
+ version=self.version)
249
+
250
+ def dict(self):
251
+ if len(self.get_images()) > 0:
252
+ return {
253
+ "system": self.system,
254
+ "roles": self.roles,
255
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
256
+ "offset": self.offset,
257
+ "sep": self.sep,
258
+ "sep2": self.sep2,
259
+ }
260
+ return {
261
+ "system": self.system,
262
+ "roles": self.roles,
263
+ "messages": self.messages,
264
+ "offset": self.offset,
265
+ "sep": self.sep,
266
+ "sep2": self.sep2,
267
+ }
268
+
269
+
270
+ conv_vicuna_v0 = Conversation(
271
+ system="A chat between a curious human and an artificial intelligence assistant. "
272
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
273
+ roles=("Human", "Assistant"),
274
+ messages=(
275
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
276
+ ("Assistant",
277
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
278
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
279
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
280
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
281
+ "renewable and non-renewable energy sources:\n"
282
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
283
+ "energy sources are finite and will eventually run out.\n"
284
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
285
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
286
+ "and other negative effects.\n"
287
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
288
+ "have lower operational costs than non-renewable sources.\n"
289
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
290
+ "locations than non-renewable sources.\n"
291
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
292
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
293
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
294
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
295
+ ),
296
+ offset=2,
297
+ sep_style=SeparatorStyle.SINGLE,
298
+ sep="###",
299
+ )
300
+
301
+ conv_vicuna_v1 = Conversation(
302
+ system="A chat between a curious user and an artificial intelligence assistant. "
303
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
304
+ roles=("USER", "ASSISTANT"),
305
+ version="v1",
306
+ messages=(),
307
+ offset=0,
308
+ sep_style=SeparatorStyle.TWO,
309
+ sep=" ",
310
+ sep2="</s>",
311
+ )
312
+
313
+ conv_llama_2 = Conversation(
314
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
315
+
316
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
317
+ roles=("USER", "ASSISTANT"),
318
+ version="llama_v2",
319
+ messages=(),
320
+ offset=0,
321
+ sep_style=SeparatorStyle.LLAMA_2,
322
+ sep="<s>",
323
+ sep2="</s>",
324
+ )
325
+
326
+
327
+ conv_blip3o_llama_2 = Conversation(
328
+ system="You are a helpful language and vision assistant. "
329
+ "You are able to understand the visual content that the user provides, "
330
+ "and assist the user with a variety of tasks using natural language.",
331
+ roles=("USER", "ASSISTANT"),
332
+ version="llama_v2",
333
+ messages=(),
334
+ offset=0,
335
+ sep_style=SeparatorStyle.LLAMA_2,
336
+ sep="<s>",
337
+ sep2="</s>",
338
+ )
339
+
340
+ conv_mpt = Conversation(
341
+ system="""<|im_start|>system
342
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
343
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
344
+ version="mpt",
345
+ messages=(),
346
+ offset=0,
347
+ sep_style=SeparatorStyle.MPT,
348
+ sep="<|im_end|>",
349
+ )
350
+
351
+ conv_blip3o_plain = Conversation(
352
+ system="",
353
+ roles=("", ""),
354
+ messages=(
355
+ ),
356
+ offset=0,
357
+ sep_style=SeparatorStyle.PLAIN,
358
+ sep="\n",
359
+ )
360
+
361
+ conv_blip3o_v0 = Conversation(
362
+ system="A chat between a curious human and an artificial intelligence assistant. "
363
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
364
+ roles=("Human", "Assistant"),
365
+ messages=(
366
+ ),
367
+ offset=0,
368
+ sep_style=SeparatorStyle.SINGLE,
369
+ sep="###",
370
+ )
371
+
372
+ conv_blip3o_v0_mmtag = Conversation(
373
+ system="A chat between a curious user and an artificial intelligence assistant. "
374
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
375
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
376
+ roles=("Human", "Assistant"),
377
+ messages=(
378
+ ),
379
+ offset=0,
380
+ sep_style=SeparatorStyle.SINGLE,
381
+ sep="###",
382
+ version="v0_mmtag",
383
+ )
384
+
385
+ conv_blip3o_v1 = Conversation(
386
+ system="A chat between a curious human and an artificial intelligence assistant. "
387
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
388
+ roles=("USER", "ASSISTANT"),
389
+ version="v1",
390
+ messages=(),
391
+ offset=0,
392
+ sep_style=SeparatorStyle.TWO,
393
+ sep=" ",
394
+ sep2="</s>",
395
+ )
396
+
397
+ conv_blip3o_v1_mmtag = Conversation(
398
+ system="A chat between a curious user and an artificial intelligence assistant. "
399
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
400
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
401
+ roles=("USER", "ASSISTANT"),
402
+ messages=(),
403
+ offset=0,
404
+ sep_style=SeparatorStyle.TWO,
405
+ sep=" ",
406
+ sep2="</s>",
407
+ version="v1_mmtag",
408
+ )
409
+
410
+ conv_mistral_instruct = Conversation(
411
+ system="",
412
+ roles=("USER", "ASSISTANT"),
413
+ version="llama_v2",
414
+ messages=(),
415
+ offset=0,
416
+ sep_style=SeparatorStyle.LLAMA_2,
417
+ sep="",
418
+ sep2="</s>",
419
+ )
420
+
421
+ conv_chatml_direct = Conversation(
422
+ system="""<|im_start|>system
423
+ Answer the questions.""",
424
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
425
+ version="mpt",
426
+ messages=(),
427
+ offset=0,
428
+ sep_style=SeparatorStyle.MPT,
429
+ sep="<|im_end|>",
430
+ )
431
+
432
+ conv_llama3 = Conversation(
433
+ system="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""",
434
+ roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
435
+ version="llama3",
436
+ messages=(),
437
+ offset=0,
438
+ sep_style=SeparatorStyle.MPT,
439
+ sep="<|eot_id|>",
440
+ )
441
+
442
+ conv_qwen = Conversation(
443
+ system="""<|im_start|>system
444
+ You are a helpful assistant.""",
445
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
446
+ version="qwen",
447
+ messages=[],
448
+ offset=0,
449
+ sep_style=SeparatorStyle.CHATML,
450
+ sep="<|im_end|>",
451
+ )
452
+
453
+
454
+ default_conversation = conv_llama3
455
+ conv_templates = {
456
+ "default": conv_vicuna_v0,
457
+ "v0": conv_vicuna_v0,
458
+ "v1": conv_vicuna_v1,
459
+ "vicuna_v1": conv_vicuna_v1,
460
+ "llama_2": conv_llama_2,
461
+ "mistral_instruct": conv_mistral_instruct,
462
+ "chatml_direct": conv_chatml_direct,
463
+ "mistral_direct": conv_chatml_direct,
464
+
465
+ "plain": conv_blip3o_plain,
466
+ "v0_plain": conv_blip3o_plain,
467
+ "blip3o_v0": conv_blip3o_v0,
468
+ "v0_mmtag": conv_blip3o_v0_mmtag,
469
+ "blip3o_v1": conv_blip3o_v1,
470
+ "v1_mmtag": conv_blip3o_v1_mmtag,
471
+ "blip3o_llama_2": conv_blip3o_llama_2,
472
+ "llama3": conv_llama3,
473
+ "qwen": conv_qwen,
474
+
475
+ "mpt": conv_mpt,
476
+ }
477
+
478
+ if __name__ == "__main__":
479
+ print(default_conversation.get_prompt())
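A short usage sketch for the templates above (the question text is made up): a prompt is built by copying a template, appending user and assistant turns, and calling get_prompt().

from blip3o.conversation import conv_templates

conv = conv_templates["qwen"].copy()  # ChatML-style template
conv.append_message(conv.roles[0], "<image>\nWhat is shown in this image?")
conv.append_message(conv.roles[1], None)  # leave the assistant turn open for generation
prompt = conv.get_prompt()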
blip3o/mm_utils.py ADDED
@@ -0,0 +1,247 @@
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+ import torch
5
+ import math
6
+ import ast
7
+
8
+ from transformers import StoppingCriteria
9
+ from blip3o.constants import IMAGE_TOKEN_IDX
10
+
11
+
12
+ def select_best_resolution(original_size, possible_resolutions):
13
+ """
14
+ Selects the best resolution from a list of possible resolutions based on the original size.
15
+
16
+ Args:
17
+ original_size (tuple): The original size of the image in the format (width, height).
18
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
19
+
20
+ Returns:
21
+ tuple: The best fit resolution in the format (width, height).
22
+ """
23
+ original_width, original_height = original_size
24
+ best_fit = None
25
+ max_effective_resolution = 0
26
+ min_wasted_resolution = float('inf')
27
+
28
+ for width, height in possible_resolutions:
29
+ scale = min(width / original_width, height / original_height)
30
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
31
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
32
+ wasted_resolution = (width * height) - effective_resolution
33
+
34
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
35
+ max_effective_resolution = effective_resolution
36
+ min_wasted_resolution = wasted_resolution
37
+ best_fit = (width, height)
38
+
39
+ return best_fit
40
+
41
+
42
+ def resize_and_pad_image(image, target_resolution):
43
+ """
44
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
45
+
46
+ Args:
47
+ image (PIL.Image.Image): The input image.
48
+ target_resolution (tuple): The target resolution (width, height) of the image.
49
+
50
+ Returns:
51
+ PIL.Image.Image: The resized and padded image.
52
+ """
53
+ original_width, original_height = image.size
54
+ target_width, target_height = target_resolution
55
+
56
+ scale_w = target_width / original_width
57
+ scale_h = target_height / original_height
58
+
59
+ if scale_w < scale_h:
60
+ new_width = target_width
61
+ new_height = min(math.ceil(original_height * scale_w), target_height)
62
+ else:
63
+ new_height = target_height
64
+ new_width = min(math.ceil(original_width * scale_h), target_width)
65
+
66
+ # Resize the image
67
+ resized_image = image.resize((new_width, new_height))
68
+
69
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
70
+ paste_x = (target_width - new_width) // 2
71
+ paste_y = (target_height - new_height) // 2
72
+ new_image.paste(resized_image, (paste_x, paste_y))
73
+
74
+ return new_image
75
+
76
+
77
+ def divide_to_patches(image, patch_size):
78
+ """
79
+ Divides an image into patches of a specified size.
80
+
81
+ Args:
82
+ image (PIL.Image.Image): The input image.
83
+ patch_size (int): The size of each patch.
84
+
85
+ Returns:
86
+ list: A list of PIL.Image.Image objects representing the patches.
87
+ """
88
+ patches = []
89
+ width, height = image.size
90
+ for i in range(0, height, patch_size):
91
+ for j in range(0, width, patch_size):
92
+ box = (j, i, j + patch_size, i + patch_size)
93
+ patch = image.crop(box)
94
+ patches.append(patch)
95
+
96
+ return patches
97
+
98
+
99
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
100
+ """
101
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
102
+
103
+ Args:
104
+ image_size (tuple): The size of the input image in the format (width, height).
105
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
106
+ patch_size (int): The size of each image patch.
107
+
108
+ Returns:
109
+ tuple: The shape of the image patch grid in the format (width, height).
110
+ """
111
+ if type(grid_pinpoints) is list:
112
+ possible_resolutions = grid_pinpoints
113
+ else:
114
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
115
+ width, height = select_best_resolution(image_size, possible_resolutions)
116
+ return width // patch_size, height // patch_size
117
+
118
+
119
+ def process_anyres_image(image, processor, grid_pinpoints):
120
+ """
121
+ Process an image with variable resolutions.
122
+
123
+ Args:
124
+ image (PIL.Image.Image): The input image to be processed.
125
+ processor: The image processor object.
126
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
127
+
128
+ Returns:
129
+ torch.Tensor: A tensor containing the processed image patches.
130
+ """
131
+ if type(grid_pinpoints) is list:
132
+ possible_resolutions = grid_pinpoints
133
+ else:
134
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
135
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
136
+ image_padded = resize_and_pad_image(image, best_resolution)
137
+
138
+ patches = divide_to_patches(image_padded, processor.crop_size['height'])
139
+
140
+ image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
141
+
142
+ image_patches = [image_original_resize] + patches
143
+ image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
144
+ for image_patch in image_patches]
145
+ return torch.stack(image_patches, dim=0)
146
+
147
+
148
+ def load_image_from_base64(image):
149
+ return Image.open(BytesIO(base64.b64decode(image)))
150
+
151
+
152
+ def expand2square(pil_img, background_color):
153
+ width, height = pil_img.size
154
+ if width == height:
155
+ return pil_img
156
+ elif width > height:
157
+ result = Image.new(pil_img.mode, (width, width), background_color)
158
+ result.paste(pil_img, (0, (width - height) // 2))
159
+ return result
160
+ else:
161
+ result = Image.new(pil_img.mode, (height, height), background_color)
162
+ result.paste(pil_img, ((height - width) // 2, 0))
163
+ return result
164
+
165
+
166
+ def process_images(images, image_processor, model_cfg):
167
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
168
+ new_images = []
169
+ if image_aspect_ratio == 'pad':
170
+ for image in images:
171
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
172
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
173
+ new_images.append(image)
174
+ elif image_aspect_ratio == "anyres":
175
+ for image in images:
176
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
177
+ new_images.append(image)
178
+ else:
179
+ return image_processor(images, return_tensors='pt')['pixel_values']
180
+ if all(x.shape == new_images[0].shape for x in new_images):
181
+ new_images = torch.stack(new_images, dim=0)
182
+ return new_images
183
+
184
+
185
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_IDX, return_tensors=None):
186
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
187
+
188
+ def insert_separator(X, sep):
189
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
190
+
191
+ input_ids = []
192
+ offset = 0
193
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
194
+ offset = 1
195
+ input_ids.append(prompt_chunks[0][0])
196
+
197
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
198
+ input_ids.extend(x[offset:])
199
+
200
+ if return_tensors is not None:
201
+ if return_tensors == 'pt':
202
+ return torch.tensor(input_ids, dtype=torch.long)
203
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
204
+ return input_ids
205
+
206
+
207
+ def get_model_name_from_path(model_path):
208
+ model_path = model_path.strip("/")
209
+ model_paths = model_path.split("/")
210
+ if model_paths[-1].startswith('checkpoint-'):
211
+ return model_paths[-2] + "_" + model_paths[-1]
212
+ else:
213
+ return model_paths[-1]
214
+
215
+ class KeywordsStoppingCriteria(StoppingCriteria):
216
+ def __init__(self, keywords, tokenizer, input_ids):
217
+ self.keywords = keywords
218
+ self.keyword_ids = []
219
+ self.max_keyword_len = 0
220
+ for keyword in keywords:
221
+ cur_keyword_ids = tokenizer(keyword).input_ids
222
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
223
+ cur_keyword_ids = cur_keyword_ids[1:]
224
+ if len(cur_keyword_ids) > self.max_keyword_len:
225
+ self.max_keyword_len = len(cur_keyword_ids)
226
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
227
+ self.tokenizer = tokenizer
228
+ self.start_len = input_ids.shape[1]
229
+
230
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
231
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
232
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
233
+ for keyword_id in self.keyword_ids:
234
+ truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
235
+ if torch.equal(truncated_output_ids, keyword_id):
236
+ return True
237
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
238
+ for keyword in self.keywords:
239
+ if keyword in outputs:
240
+ return True
241
+ return False
242
+
243
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
244
+ outputs = []
245
+ for i in range(output_ids.shape[0]):
246
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
247
+ return all(outputs)
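A rough sketch of how the helpers above fit together, assuming a tokenizer and model have already been loaded (for example via blip3o.model.builder) and omitting the image tensors and model-specific generate arguments for brevity:

from blip3o.constants import IMAGE_TOKEN_IDX
from blip3o.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria

# tokenizer and model are assumed to be loaded elsewhere (see model/builder.py).
prompt = "<image>\nWhat is shown here?"  # hypothetical prompt
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_IDX,
                                  return_tensors="pt").unsqueeze(0).to(model.device)

# Stop generation as soon as the separator keyword appears in the output.
stopping = KeywordsStoppingCriteria(["<|im_end|>"], tokenizer, input_ids)
output_ids = model.generate(input_ids, stopping_criteria=[stopping], max_new_tokens=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))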
blip3o/model/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .language_model.blip3o_llama import blip3oLlamaForCausalLM, blip3oConfig
2
+ from .language_model.blip3o_qwen import blip3oQwenForCausalLM, blip3oQwenConfig
3
+
blip3o/model/apply_delta.py ADDED
@@ -0,0 +1,48 @@
1
+ """
2
+ Usage:
3
+ python3 -m blip3o.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from blip3o import blip3oLlamaForCausalLM
11
+
12
+
13
+ def apply_delta(base_model_path, target_model_path, delta_path):
14
+ print("Loading base model")
15
+ base = AutoModelForCausalLM.from_pretrained(
16
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading delta")
19
+ delta = blip3oLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21
+
22
+ print("Applying delta")
23
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data += base.state_dict()[name]
29
+ else:
30
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31
+ f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32
+ bparam = base.state_dict()[name]
33
+ param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34
+
35
+ print("Saving target model")
36
+ delta.save_pretrained(target_model_path)
37
+ delta_tokenizer.save_pretrained(target_model_path)
38
+
39
+
40
+ if __name__ == "__main__":
41
+ parser = argparse.ArgumentParser()
42
+ parser.add_argument("--base-model-path", type=str, required=True)
43
+ parser.add_argument("--target-model-path", type=str, required=True)
44
+ parser.add_argument("--delta-path", type=str, required=True)
45
+
46
+ args = parser.parse_args()
47
+
48
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
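A usage sketch with placeholder paths; the same operation is exposed on the command line through the argparse block above.

from blip3o.model.apply_delta import apply_delta

# All three paths are placeholders for local checkpoints.
apply_delta(
    base_model_path="/path/to/llama-base",
    target_model_path="/path/to/blip3o-merged",
    delta_path="/path/to/blip3o-delta",
)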
blip3o/model/blip3o_arch.py ADDED
@@ -0,0 +1,415 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from .multimodal_encoder.builder import build_vision_tower, build_gen_vision_tower, build_dit
8
+ from .multimodal_projector.builder import build_vision_projector, build_down_projector, build_gen_vision_projector
9
+
10
+ from blip3o.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_IDX, DEFAULT_IM_START_TOKEN_IDX, DEFAULT_IM_END_TOKEN_IDX, UND_IMAGE_TOKEN_IDX
11
+
12
+
13
+
14
+ class blip3oMetaModel:
15
+
16
+ def __init__(self, config):
17
+ super(blip3oMetaModel, self).__init__(config)
18
+
19
+ if hasattr(config, "mm_vision_tower"):
20
+ # self.vision_tower = build_vision_tower(config, delay_load=True)
21
+ # self.mm_projector = build_vision_projector(config)
22
+ self.down_projector = build_down_projector(config)
23
+
24
+ if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
25
+ self.image_newline = nn.Parameter(
26
+ torch.empty(config.hidden_size, dtype=self.dtype)
27
+ )
28
+
29
+
30
+ if hasattr(config, "gen_vision_tower"):
31
+ self.gen_vision_tower = build_gen_vision_tower(config, delay_load=True)
32
+ # self.gen_projector = build_gen_vision_projector(config)
33
+ self.latent_queries = nn.Parameter(torch.randn(1, config.n_query, config.hidden_size))
34
+ print(f" latent query size {self.latent_queries.shape}")
35
+
36
+ if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
37
+ self.image_newline = nn.Parameter(
38
+ torch.empty(config.hidden_size, dtype=self.dtype)
39
+ )
40
+
41
+ self.dit, self.vae, self.noise_scheduler = build_dit(config)
42
+
43
+
44
+ # def get_vision_tower(self):
45
+ # vision_tower = getattr(self, 'vision_tower', None)
46
+ # if type(vision_tower) is list:
47
+ # vision_tower = vision_tower[0]
48
+ # return vision_tower
49
+
50
+
51
+ def get_gen_vision_tower(self):
52
+ gen_vision_tower = getattr(self, 'gen_vision_tower', None)
53
+ if type(gen_vision_tower) is list:
54
+ gen_vision_tower = gen_vision_tower[0]
55
+ return gen_vision_tower
56
+
57
+
58
+ def initialize_vision_modules(self, model_args, fsdp=None):
59
+ gen_vision_tower = model_args.gen_vision_tower
60
+
61
+ mm_vision_select_layer = model_args.mm_vision_select_layer
62
+ mm_vision_select_feature = model_args.mm_vision_select_feature
63
+
64
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
65
+ pretrain_gen_mlp_adapter = model_args.pretrain_gen_mlp_adapter
66
+
67
+ mm_patch_merge_type = model_args.mm_patch_merge_type
68
+
69
+ self.config.gen_vision_tower = gen_vision_tower
70
+ self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
71
+
72
+
73
+
74
+ if getattr(self, 'dit', None) is None:
75
+ print("random initiation the DiT !!!")
76
+ self.dit, self.vae, self.noise_scheduler = build_dit(model_args)
77
+ else:
78
+ print("DiT load from checkpoint!!!")
79
+ for p in self.dit.parameters():
80
+ p.requires_grad = True
81
+
82
+
83
+ if self.get_gen_vision_tower() is None:
84
+ gen_vision_tower = build_gen_vision_tower(model_args)
85
+
86
+ if fsdp is not None and len(fsdp) > 0:
87
+ self.gen_vision_tower = [gen_vision_tower]
88
+ else:
89
+ self.gen_vision_tower = gen_vision_tower
90
+ else:
91
+ if fsdp is not None and len(fsdp) > 0:
92
+ gen_vision_tower = self.gen_vision_tower[0]
93
+ else:
94
+ gen_vision_tower = self.gen_vision_tower
95
+ gen_vision_tower.load_model()
96
+
97
+
98
+ self.config.use_mm_proj = True
99
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
100
+ # self.config.gen_projector_type = getattr(model_args, 'gen_projector_type', 'linear')
101
+
102
+
103
+ self.config.gen_hidden_size = gen_vision_tower.hidden_size
104
+
105
+ self.config.mm_vision_select_layer = mm_vision_select_layer
106
+ self.config.mm_vision_select_feature = mm_vision_select_feature
107
+ self.config.mm_patch_merge_type = mm_patch_merge_type
108
+ self.config.n_query = model_args.n_query
109
+ self.config.gen_pooling = model_args.gen_pooling
110
+
111
+ # if getattr(self, 'mm_projector', None) is None:
112
+ # print("random initiation the mm_project !!!")
113
+ # self.mm_projector = build_vision_projector(self.config)
114
+
115
+ # if 'unpad' in mm_patch_merge_type:
116
+ # embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
117
+ # self.image_newline = nn.Parameter(
118
+ # torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
119
+ # )
120
+ # else:
121
+ # # In case it is frozen by LoRA
122
+ # for p in self.mm_projector.parameters():
123
+ # p.requires_grad = True
124
+
125
+
126
+
127
+ if getattr(self, 'down_projector', None) is None:
128
+ print("random initiation the down_projector !!!")
129
+ self.down_projector = build_down_projector(self.config)
130
+ else:
131
+ # In case it is frozen by LoRA
132
+ for p in self.down_projector.parameters():
133
+ p.requires_grad = True
134
+
135
+ if getattr(self, 'latent_queries', None) is None:
136
+ print("random initiation the latent_queries !!!")
137
+ self.latent_queries = nn.Parameter(torch.randn(1, self.config.n_query, self.config.hidden_size))
138
+ else:
139
+ print("latent_queries load from checkpoint!!!")
140
+ self.latent_queries.requires_grad = True
141
+
142
+
143
+ if pretrain_mm_mlp_adapter is not None:
144
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
145
+ def get_w(weights, keyword):
146
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
147
+
148
+ # self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
149
+
150
+
151
+
152
+ def unpad_image(tensor, original_size):
153
+ """
154
+ Unpads a PyTorch tensor of a padded and resized image.
155
+
156
+ Args:
157
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
158
+ original_size (tuple): The original size of PIL image (width, height).
159
+
160
+ Returns:
161
+ torch.Tensor: The unpadded image tensor.
162
+ """
163
+ original_width, original_height = original_size
164
+ current_height, current_width = tensor.shape[1:]
165
+
166
+ original_aspect_ratio = original_width / original_height
167
+ current_aspect_ratio = current_width / current_height
168
+
169
+ if original_aspect_ratio > current_aspect_ratio:
170
+ scale_factor = current_width / original_width
171
+ new_height = int(original_height * scale_factor)
172
+ padding = (current_height - new_height) // 2
173
+ unpadded_tensor = tensor[:, padding:current_height - padding, :]
174
+ else:
175
+ scale_factor = current_height / original_height
176
+ new_width = int(original_width * scale_factor)
177
+ padding = (current_width - new_width) // 2
178
+ unpadded_tensor = tensor[:, :, padding:current_width - padding]
179
+
180
+ return unpadded_tensor
181
+
182
+
183
+ class blip3oMetaForCausalLM(ABC):
184
+
185
+ @abstractmethod
186
+ def get_model(self):
187
+ pass
188
+
189
+ def get_vision_tower(self):
190
+ return self.get_model().get_vision_tower()
191
+
192
+ def get_gen_vision_tower(self):
193
+ return self.get_model().get_gen_vision_tower()
194
+
195
+ def encode_image(self, images):
196
+ # breakpoint()
197
+ gen_vision_tower = self.get_gen_vision_tower()
198
+ device = gen_vision_tower.device
199
+ images = images.to(device)
200
+ prompt_image_embeds = gen_vision_tower(images)
201
+ if 'early' in self.get_gen_pooling():
202
+ prompt_image_embeds = self.pool_img(prompt_image_embeds)
203
+ num_img, _, c = prompt_image_embeds.shape
204
+ # prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c)
205
+
206
+ # ------------- compute similarity -------
207
+ all_dist = 0
208
+ count = 0
209
+ for i in range(2, prompt_image_embeds.shape[1]-1):
210
+ diff = (prompt_image_embeds[:,i,:].unsqueeze(1) - prompt_image_embeds[:,:i,:])
211
+ dist = torch.sqrt(diff.square().sum(-1)).min().item()
212
+ all_dist+=dist
213
+ count+=1
214
+ all_dist /= count
215
+ # self.dist = all_dist
216
+ # print(self.dist)
217
+
218
+ return prompt_image_embeds
219
+
220
+ def get_mm_projector(self):
221
+ return self.get_model().mm_projector
222
+
223
+ def get_gen_projector(self):
224
+ return None
225
+
226
+
227
+ def get_n_query(self):
228
+ return self.get_model().config.n_query
229
+
230
+ def get_gen_pooling(self):
231
+ return self.get_model().config.gen_pooling
232
+
233
+ def pool_img(self, image_features):
234
+ num_img, n, c = image_features.shape
235
+ gen_pooling = self.get_gen_pooling()
236
+ # n_query = self.get_n_query()
237
+ stride = int(gen_pooling.split('_')[-1])
238
+ sqrt_n = int(n**0.5)
239
+ image_features = image_features.permute(0, 2, 1).view(num_img, c, sqrt_n, sqrt_n)
240
+ image_features = F.avg_pool2d(image_features, kernel_size=(stride, stride), stride=stride)
241
+ # image_features = image_features.view(num_img, c, -1).permute(0,2,1).contiguous()
242
+ return image_features
243
+
244
+ def get_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32):
245
+ sigmas = self.get_model().noise_scheduler.sigmas.to(device=device, dtype=dtype)
246
+ schedule_timesteps = self.get_model().noise_scheduler.timesteps.to(device=device)
247
+ timesteps = timesteps.to(device)
248
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
249
+
250
+ sigma = sigmas[step_indices].flatten()
251
+ while len(sigma.shape) < n_dim:
252
+ sigma = sigma.unsqueeze(-1)
253
+ return sigma
254
+
255
+ def mask_drop(self, latents, drop_prob=0.1):
256
+ if drop_prob <= 0:
257
+ return latents
258
+ mask = torch.bernoulli(torch.zeros(latents.shape[0], device=latents.device, dtype=latents.dtype) + drop_prob)
259
+ while len(mask.shape) < len(latents.shape):
260
+ mask = mask.unsqueeze(-1)
261
+ mask = 1 - mask # need to flip 0 <-> 1
262
+ return latents * mask
263
+
264
+ def prepare_inputs_labels_for_multimodal(
265
+ self, input_ids, position_ids, attention_mask, past_key_values, labels,
266
+ gen_images, und_images, grid_thw, i_s_pos, image_sizes=None
267
+ ):
268
+ pad_ids = 128256
269
+ vision_tower = self.visual
270
+ gen_vision_tower = self.get_gen_vision_tower()
271
+ if (gen_images is None and und_images is None) or input_ids.shape[1] == 1:
272
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels, None  # match the 7-tuple returned at the end of this method
273
+
274
+
275
+
276
+
277
+ prompt_image_embeds = gen_vision_tower(gen_images) # TODO: check dimension
278
+
279
+ if 'early' in self.get_gen_pooling():
280
+ prompt_image_embeds = self.pool_img(prompt_image_embeds)
281
+ target_image_embeds = torch.clone(prompt_image_embeds).detach()
282
+ latent_queries = self.get_model().latent_queries.repeat(input_ids.shape[0], 1, 1)
283
+ H = latent_queries.shape[-1]
284
+ latent_queries = latent_queries.contiguous().view(-1, H)
285
+
286
+
287
+
288
+
289
+ # if not gen_images is None:
290
+ # prompt_image_embeds = gen_vision_tower(gen_images) # TODO: check dimension
291
+ # if 'early' in self.get_gen_pooling():
292
+ # prompt_image_embeds = self.pool_img(prompt_image_embeds)
293
+ # # num_img, _, c = prompt_image_embeds.shape # [batch, 729, 1152]
294
+ # # prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c)
295
+ # target_image_embeds = torch.clone(prompt_image_embeds).detach()
296
+ # # prompt_image_embeds = gen_projector(prompt_image_embeds)
297
+ # latent_queries = self.get_model().latent_queries.repeat(input_ids.shape[0], 1, 1)
298
+ # H = latent_queries.shape[-1]
299
+ # latent_queries = latent_queries.contiguous().view(-1, H)
300
+ # else:
301
+ # target_image_embeds = None
302
+ # num_img = und_images.shape[0]
303
+ # dummy = torch.zeros(num_img, 3, 448, 448 , dtype=und_images.dtype, device=und_images.device) # TODO
304
+ # temp = gen_vision_tower(dummy)[:,:729,:]
305
+ # num_img, _, c = temp.shape
306
+ # temp = temp.contiguous().view(-1, c) * 1e-20
307
+ # # temp = gen_projector(temp) * 1e-9
308
+ # latent_queries = self.get_model().latent_queries.repeat(input_ids.shape[0], 1, 1)
309
+ # H = latent_queries.shape[-1]
310
+ # latent_queries = latent_queries.contiguous().view(-1, H)
311
+
312
+
313
+ if not und_images is None:
314
+ und_image_embeds = vision_tower(und_images, grid_thw=grid_thw)
315
+ # _, c = und_image_embeds.shape
316
+ # batch_size = und_images.shape[0]
317
+ # und_image_embeds = und_image_embeds.view(batch_size, -1, c)
318
+ # und_image_embeds = und_image_embeds.contiguous().view(-1, c)
319
+ # und_image_embeds = mm_projector(und_image_embeds)
320
+
321
+ # else:
322
+ # num_img = input_ids.shape[0]
323
+ # dummy = torch.zeros(num_img, 3, 384, 384 , dtype=gen_images.dtype, device=gen_images.device) # clip (3, 336, 336)
324
+ # temp = vision_tower(dummy)
325
+ # if 'early' in self.get_gen_pooling():
326
+ # temp = temp[:,:64,:]
327
+ # num_img, _, c = temp.shape
328
+ # temp = temp.contiguous().view(-1, c)
329
+ # temp = mm_projector(temp) * 1e-20
330
+ # latent_queries += temp
331
+
332
+
333
+
334
+
335
+
336
+ image_idx = (input_ids == IMAGE_TOKEN_IDX)
337
+ und_image_idx = (input_ids == UND_IMAGE_TOKEN_IDX)
338
+ # img_indicator = torch.clone(image_idx)
339
+ output_indicator = labels != -100
340
+ input_indicator = labels == -100
341
+ # img_loss_indicator = torch.logical_and(output_indicator, image_idx)
342
+ # img_loss_indicator = torch.cat(
343
+ # [img_loss_indicator[:, 1:], img_loss_indicator[:, :1]], dim=1)
344
+
345
+ # img_indicator = torch.cat(
346
+ # [img_indicator[:, 1:], img_indicator[:, :1]], dim=1)
347
+
348
+ # if not target_image_embeds is None:
349
+ # target_image_embeds = target_image_embeds[-img_loss_indicator.sum():,:]
350
+ text_embeds = self.get_model().embed_tokens(input_ids)
351
+ # N_QUERY = self.get_n_query()
352
+ gen_img_idx = torch.logical_and(output_indicator, image_idx)
353
+
354
+ # if not target_image_embeds is None:
355
+ text_embeds = text_embeds.clone()
356
+ text_embeds[gen_img_idx] = latent_queries
357
+ # text_embeds[gen_img_idx] = prompt_image_embeds.to(text_embeds.device)[:gen_img_idx.sum(),:]
358
+ # target_image_embeds = target_image_embeds.to(text_embeds.device)[:gen_img_idx.sum(),:]
359
+
360
+ und_img_idx = torch.logical_and(input_indicator, und_image_idx)
361
+
362
+
363
+ if not und_images is None:
364
+ text_embeds[und_img_idx] = und_image_embeds.to(text_embeds.device)[:und_img_idx.sum(), :]
365
+
366
+ labels[image_idx] = -100
367
+
368
+
369
+ return None, position_ids, attention_mask, past_key_values, text_embeds, labels, target_image_embeds
370
+
371
+
372
+
373
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
374
+ if model_args.mm_use_im_patch_token:
375
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
376
+ self.resize_token_embeddings(len(tokenizer))
377
+
378
+ if model_args.mm_use_im_start_end:
379
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
380
+ self.resize_token_embeddings(len(tokenizer))
381
+
382
+ if num_new_tokens > 0:
383
+ input_embeddings = self.get_input_embeddings().weight.data
384
+ output_embeddings = self.get_output_embeddings().weight.data
385
+
386
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
387
+ dim=0, keepdim=True)
388
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
389
+ dim=0, keepdim=True)
390
+
391
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
392
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
393
+
394
+ if model_args.tune_mm_mlp_adapter:
395
+ for p in self.get_input_embeddings().parameters():
396
+ p.requires_grad = True
397
+ for p in self.get_output_embeddings().parameters():
398
+ p.requires_grad = False
399
+
400
+ if model_args.pretrain_mm_mlp_adapter:
401
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
402
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
403
+ assert num_new_tokens == 2
404
+ if input_embeddings.shape == embed_tokens_weight.shape:
405
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
406
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
407
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
408
+ else:
409
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
410
+ elif model_args.mm_use_im_patch_token:
411
+ if model_args.tune_mm_mlp_adapter:
412
+ for p in self.get_input_embeddings().parameters():
413
+ p.requires_grad = False
414
+ for p in self.get_output_embeddings().parameters():
415
+ p.requires_grad = False
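mask_drop above implements the usual classifier-free-guidance dropout: with probability drop_prob the entire conditioning tensor of a sample is zeroed, so the diffusion head also learns an unconditional path. A self-contained sketch of the same idea, with placeholder shapes:

import torch

def cfg_mask_drop(latents: torch.Tensor, drop_prob: float = 0.1) -> torch.Tensor:
    if drop_prob <= 0:
        return latents
    # One Bernoulli draw per sample; a draw of 1 means "drop this sample's conditioning".
    drop = torch.bernoulli(torch.full((latents.shape[0],), drop_prob,
                                      device=latents.device, dtype=latents.dtype))
    keep = (1.0 - drop).view(-1, *([1] * (latents.dim() - 1)))  # broadcast over remaining dims
    return latents * keep

cond = torch.randn(4, 64, 1024)            # placeholder (batch, n_query, hidden) conditioning
cond = cfg_mask_drop(cond, drop_prob=0.1)  # roughly 10% of samples get all-zero conditioning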
blip3o/model/builder.py ADDED
@@ -0,0 +1,54 @@
1
+ import os
2
+ import warnings
3
+ import shutil
4
+
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
6
+ import torch
7
+ from blip3o.model import *
8
+ from blip3o.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9
+ from blip3o.train.train import smart_tokenizer_and_embedding_resize
10
+
11
+
12
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
13
+ kwargs = {"device_map": device_map, **kwargs}
14
+
15
+ if device != "cuda":
16
+ kwargs['device_map'] = {"": device}
17
+
18
+ if load_8bit:
19
+ kwargs['load_in_8bit'] = True
20
+ elif load_4bit:
21
+ kwargs['load_in_4bit'] = True
22
+ kwargs['quantization_config'] = BitsAndBytesConfig(
23
+ load_in_4bit=True,
24
+ bnb_4bit_compute_dtype=torch.float16,
25
+ bnb_4bit_use_double_quant=True,
26
+ bnb_4bit_quant_type='nf4'
27
+ )
28
+ else:
29
+ kwargs['torch_dtype'] = torch.float16
30
+
31
+ if use_flash_attn:
32
+ kwargs['attn_implementation'] = 'flash_attention_2'
33
+
34
+
35
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
36
+
37
+ model = blip3oQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.float16).to('cuda:0')
38
+
39
+ image_processor = None
40
+
41
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
42
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
43
+ if mm_use_im_patch_token:
44
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
45
+ if mm_use_im_start_end:
46
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
47
+ model.resize_token_embeddings(len(tokenizer))
48
+
49
+ if hasattr(model.config, "max_sequence_length"):
50
+ context_len = model.config.max_sequence_length
51
+ else:
52
+ context_len = 2048
53
+
54
+ return tokenizer, model, context_len
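A usage sketch for the loader above; the checkpoint path is a placeholder, and note that this variant returns three values (tokenizer, model, context_len) rather than the four returned by LLaVA-style loaders.

from blip3o.model.builder import load_pretrained_model
from blip3o.mm_utils import get_model_name_from_path

model_path = "/path/to/blip3o-checkpoint"  # placeholder
tokenizer, model, context_len = load_pretrained_model(
    model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path),
)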
blip3o/model/consolidate.py ADDED
@@ -0,0 +1,25 @@
1
+ import argparse
2
+
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ from blip3o.model import *
6
+ from blip3o.model.utils import auto_upgrade
7
+
8
+
9
+ def consolidate_ckpt(src_path, dst_path):
10
+ print("Loading model")
11
+ auto_upgrade(src_path)
12
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
13
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
14
+ src_model.save_pretrained(dst_path)
15
+ src_tokenizer.save_pretrained(dst_path)
16
+
17
+
18
+ if __name__ == "__main__":
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("--src", type=str, required=True)
21
+ parser.add_argument("--dst", type=str, required=True)
22
+
23
+ args = parser.parse_args()
24
+
25
+ consolidate_ckpt(args.src, args.dst)
blip3o/model/language_model/blip3o_llama.py ADDED
@@ -0,0 +1,413 @@
1
+ from typing import List, Optional, Tuple, Union
2
+ from PIL import Image
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from transformers import AutoConfig, AutoModelForCausalLM, \
9
+ LlamaConfig, LlamaModel, LlamaForCausalLM, AutoTokenizer
10
+
11
+ from transformers.modeling_outputs import CausalLMOutputWithPast
12
+ from transformers.generation.utils import GenerateOutput
13
+
14
+ from ..blip3o_arch import blip3oMetaModel, blip3oMetaForCausalLM
15
+ from blip3o.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_IDX, DEFAULT_IM_START_TOKEN_IDX, DEFAULT_IM_END_TOKEN_IDX
16
+ import pdb
17
+ from diffusers.utils.torch_utils import randn_tensor
18
+ from diffusers.pipelines.pipeline_utils import numpy_to_pil
19
+ import numpy as np
20
+ from diffusers.models import AutoencoderKL
21
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
22
+
23
+
24
+ class blip3oConfig(LlamaConfig):
25
+ model_type = "blip3o_llama"
26
+
27
+
28
+ class blip3oLlamaModel(blip3oMetaModel, LlamaModel):
29
+ config_class = blip3oConfig
30
+
31
+ def __init__(self, config: LlamaConfig):
32
+ super(blip3oLlamaModel, self).__init__(config)
33
+
34
+
35
+ class blip3oLlamaForCausalLM(LlamaForCausalLM, blip3oMetaForCausalLM):
36
+ config_class = blip3oConfig
37
+
38
+ def __init__(self, config):
39
+ super(LlamaForCausalLM, self).__init__(config)
40
+ self.model = blip3oLlamaModel(config)
41
+ self.pretraining_tp = config.pretraining_tp
42
+ self.vocab_size = config.vocab_size
43
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
44
+ self.dist = None
45
+
46
+ # Initialize weights and apply final processing
47
+ self.post_init()
48
+
49
+ def get_model(self):
50
+ return self.model
51
+
52
+ def forward(
53
+ self,
54
+ input_ids: torch.LongTensor = None,
55
+ attention_mask: Optional[torch.Tensor] = None,
56
+ position_ids: Optional[torch.LongTensor] = None,
57
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
58
+ inputs_embeds: Optional[torch.FloatTensor] = None,
59
+ labels: Optional[torch.LongTensor] = None,
60
+ ids: Optional[list] = None,
61
+ i_s_pos: Optional[list] = None,
62
+ image_type: Optional[torch.Tensor] = None,
63
+ use_cache: Optional[bool] = None,
64
+ output_attentions: Optional[bool] = None,
65
+ output_hidden_states: Optional[bool] = None,
66
+ gen_image: Optional[torch.FloatTensor] = None,
67
+ und_image: Optional[torch.FloatTensor] = None,
68
+ image_sizes: Optional[List[List[int]]] = None,
69
+ return_dict: Optional[bool] = None,
70
+ cache_position: Optional[torch.LongTensor] = None
71
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
72
+
73
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
74
+ output_hidden_states = (
75
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
76
+ )
77
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
78
+
79
+ if inputs_embeds is None:
80
+ (
81
+ input_ids,
82
+ position_ids,
83
+ attention_mask,
84
+ past_key_values,
85
+ inputs_embeds,
86
+ labels,
87
+ latents
88
+ ) = self.prepare_inputs_labels_for_multimodal(
89
+ input_ids,
90
+ position_ids,
91
+ attention_mask,
92
+ past_key_values,
93
+ labels,
94
+ gen_image,
95
+ und_image,
96
+ i_s_pos,
97
+ image_sizes
98
+ )
99
+
100
+ outputs = self.model(
101
+ input_ids=input_ids,
102
+ attention_mask=attention_mask,
103
+ position_ids=position_ids,
104
+ past_key_values=past_key_values,
105
+ inputs_embeds=inputs_embeds,
106
+ use_cache=use_cache,
107
+ output_attentions=output_attentions,
108
+ output_hidden_states=output_hidden_states,
109
+ return_dict=return_dict,
110
+ )
111
+
112
+ hidden_states = outputs[0]
113
+ logits = self.lm_head(hidden_states)
114
+ logits = logits.float()
115
+
116
+ total_loss = None
117
+ if labels is not None:
118
+ # Shift so that tokens < n predict n
119
+ shift_logits = logits[..., :-1, :].contiguous()
120
+ shift_labels = labels[..., 1:].contiguous()
121
+ # Flatten the tokens
122
+ loss_fct = torch.nn.CrossEntropyLoss()
123
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
124
+ shift_labels = shift_labels.view(-1)
125
+ # Enable model parallelism
126
+ shift_labels = shift_labels.to(shift_logits.device)
127
+ loss = loss_fct(shift_logits, shift_labels)
128
+
129
+
130
+ # compute image loss
131
+ # target_img_embeds = torch.clone(inputs_embeds.detach())[:,1:,:] # get target image emb
132
+ img_loss_funct = torch.nn.MSELoss()
133
+ # img_hidden_states = self.get_model().down_projector(hidden_states[:,-self.get_n_query():,:])
134
+ img_hidden_states = []
135
+
136
+ for b in range(hidden_states.shape[0]):
137
+ img_hidden_states.append(hidden_states[b,i_s_pos[b]:i_s_pos[b]+64,:])
138
+ img_hidden_states = torch.stack(img_hidden_states,dim=0)
139
+ img_hidden_states = self.get_model().down_projector(img_hidden_states)
140
+ # img_loss = 0.0
141
+ if latents is None:
142
+ img_loss = img_loss_funct(img_hidden_states, torch.clone(img_hidden_states.detach()))
143
+ else:
144
+ bsz = latents.shape[0]
145
+ # device = latents.device
146
+ dtype = latents.dtype
147
+ noise = torch.randn_like(latents, device=latents.device)
148
+ u = torch.rand(size=(bsz,), device="cpu")
149
+ indices = (u * self.get_model().noise_scheduler.config.num_train_timesteps).long()
150
+ timesteps = self.get_model().noise_scheduler.timesteps[indices].to(device=latents.device)
151
+ sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=dtype)
152
+ noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
153
+ noise_pred = self.get_model().dit(
154
+ x=noisy_latents,
155
+ timestep=timesteps,
156
+ z_latents=self.mask_drop(img_hidden_states),
157
+ )
158
+ target = noise - latents
159
+ img_loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
160
+ print(f"img loss {img_loss}, text loss {loss}")
161
+ total_loss = img_loss
162
+
163
+ return CausalLMOutputWithPast(
164
+ loss=total_loss,
165
+ logits=logits,
166
+ past_key_values=outputs.past_key_values,
167
+ hidden_states=outputs.hidden_states,
168
+ attentions=outputs.attentions,
169
+ )
170
+
171
+
172
+ @torch.no_grad()
173
+ def generate(
174
+ self,
175
+ inputs: Optional[torch.Tensor] = None,
176
+ images: Optional[torch.Tensor] = None,
177
+ image_sizes: Optional[torch.Tensor] = None,
178
+ **kwargs,
179
+ ) -> Union[GenerateOutput, torch.LongTensor]:
180
+ position_ids = kwargs.pop("position_ids", None)
181
+ attention_mask = kwargs.pop("attention_mask", None)
182
+ if "inputs_embeds" in kwargs:
183
+ raise NotImplementedError("`inputs_embeds` is not supported")
184
+
185
+ if images is not None:
186
+ (
187
+ inputs,
188
+ position_ids,
189
+ attention_mask,
190
+ _,
191
+ inputs_embeds,
192
+ img_indicator,
193
+ _
194
+ ) = self.prepare_inputs_labels_for_understanding(
195
+ inputs,
196
+ position_ids,
197
+ attention_mask,
198
+ None,
199
+ None,
200
+ images,
201
+ image_sizes=image_sizes
202
+ )
203
+ else:
204
+ inputs_embeds = self.get_model().embed_tokens(inputs)
205
+
206
+ return super().generate(
207
+ position_ids=position_ids,
208
+ attention_mask=attention_mask,
209
+ inputs_embeds=inputs_embeds,
210
+ **kwargs
211
+ )
212
+
213
+ @torch.no_grad()
214
+ def generate_image(
215
+ self,
216
+ text: List[str],
217
+ tokenizer: AutoTokenizer,
218
+ image: Optional[torch.Tensor] = None,
219
+ max_var: Optional[float] = None,
220
+ # placeholder: str = DEFAULT_IMG_PLACEHOLDER,
221
+ ):
222
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler")
223
+
224
+ vision_tower = self.get_vision_tower()
225
+ mm_projector = self.get_mm_projector()
226
+ N_QUERY = self.get_n_query()
227
+
228
+ if image is not None:
229
+ # image: [Batch, 3, 448, 448]
230
+ prompt_image_embeds = vision_tower(image)  # encode the image tensor passed to generate_image (`batch_images` was undefined)
231
+ num_img, _, c = prompt_image_embeds.shape # [batch, 576, 1024]
232
+ all_image_embeds = torch.clone(prompt_image_embeds).detach()
233
+ prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c)
234
+ prompt_image_embeds = mm_projector(prompt_image_embeds)
235
+
236
+ inputs = tokenizer(text, padding="longest", return_tensors="pt")
237
+ device = self.get_model().device
238
+ attention_mask = inputs.attention_mask.to(device)
239
+ input_ids = inputs.input_ids.to(device) # B x N
240
+ input_ids = torch.cat([input_ids, torch.tensor([[198]]).to(device)], dim=1)
241
+
242
+ # breakpoint()
243
+ text_embeds = self.get_model().embed_tokens(input_ids)
244
+ latent_queries = self.get_model().latent_queries.repeat(text_embeds.shape[0], 1, 1)
245
+ text_embeds = torch.cat([text_embeds, latent_queries], dim=1)
246
+ attention_mask = torch.cat([attention_mask, torch.ones_like(latent_queries[:, :, 0])], dim=1)
247
+
248
+ outputs = self.model(
249
+ inputs_embeds=text_embeds,
250
+ # img_indicator=img_indicator,
251
+ # concept_indicator=concept_indicator if self.use_concept_token else None,
252
+ attention_mask=attention_mask,
253
+ output_hidden_states=True,
254
+ return_dict=True,
255
+ )
256
+ hidden_states = outputs.hidden_states[-1][:,-N_QUERY:,:]
257
+ img_hidden_states = self.get_model().down_projector(hidden_states)
258
+ output_img = self.sample_images(img_hidden_states, scheduler)
259
+ output_img = output_img.view(1, 1792, -1).permute(0,2,1).contiguous()
260
+
261
+ return output_img
262
+
263
+ def sample_images(
264
+ self,
265
+ img_hidden_states,
266
+ scheduler,
267
+ guidance_scale: float = 3.0,
268
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
269
+ num_inference_steps: int = 30,
270
+ num_images_per_prompt: int = 1,
271
+ return_tensor=False,
272
+ **kwargs,
273
+ ):
274
+
275
+ device = img_hidden_states.device
276
+ dtype = img_hidden_states.dtype
277
+
278
+ img_hidden_states_null = torch.zeros_like(img_hidden_states, device=device, dtype=dtype)
279
+ img_hidden_states_input = torch.cat([img_hidden_states_null, img_hidden_states], 0)
280
+
281
+ batch_size = img_hidden_states.shape[0]
282
+ latent_size = self.get_model().dit.config.input_size
283
+ latent_channels = self.get_model().dit.config.in_channels
284
+
285
+ latents = randn_tensor(
286
+ shape=(batch_size * num_images_per_prompt, latent_channels, latent_size, latent_size),
287
+ generator=generator,
288
+ device=device,
289
+ dtype=dtype,
290
+ )
291
+
292
+ # set step values
293
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
294
+ scheduler.set_timesteps(num_inference_steps, sigmas=sigmas)
295
+
296
+ # Repeat z_latents and conditions for each image per prompt
297
+ img_hidden_states_input = img_hidden_states_input.repeat_interleave(num_images_per_prompt, dim=0)
298
+
299
+ for t in scheduler.timesteps:
300
+ latent_model_input = latents.repeat(2, 1, 1, 1)
301
+ if hasattr(scheduler, "scale_model_input"):
302
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
303
+
304
+ # predict noise model_output
305
+ noise_pred = self.get_model().dit(
306
+ x=latent_model_input,
307
+ timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latent_model_input.device, torch.long),
308
+ z_latents=img_hidden_states_input,
309
+ )
310
+
311
+ # perform guidance
312
+ noise_pred_uncond, noise_pred = noise_pred.chunk(2)
313
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
314
+
315
+ # compute previous image: x_t -> x_t-1
316
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
317
+
318
+ # samples = self.decode_latents(latents, return_tensor=return_tensor)
319
+ return latents
320
+
321
+ def decode_latents(self, latents, normalize=True, return_tensor=False):
322
+ if isinstance(self.get_model().vae, AutoencoderKL):
323
+ latents = latents / self.get_model().vae.config.scaling_factor
324
+ if self.get_model().vae.config.shift_factor is not None:
325
+ latents = latents + self.get_model().vae.config.shift_factor
326
+ latents = latents.to(dtype=torch.float32)
327
+ samples = self.get_model().vae.decode(latents).sample
328
+ else:
329
+ samples = self.get_model().vae.decode(latents)
330
+ if normalize:
331
+ samples = (samples / 2 + 0.5).clamp(0, 1)
332
+ else:
333
+ samples = samples.clamp(-1, 1)
334
+ if return_tensor:
335
+ return samples
336
+ samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
337
+ samples = numpy_to_pil(samples)
338
+ return samples
339
+
340
+ def prepare_and_encode_inputs(
341
+ self,
342
+ inputs: List[str | Image.Image],
343
+ tokenizer: AutoTokenizer,
344
+ do_classifier_free_guidance: bool = False,
345
+ ):
346
+ # pdb.set_trace()
347
+ device = self.get_model().device
348
+ dtype = self.get_model().dtype
349
+
350
+ has_image, has_text = False, False
351
+ text_prompt, image_prompt = "", []
352
+ img_processor = self.get_vision_tower().image_processor
353
+ negative_prompt = {}
354
+
355
+ for x in inputs:
356
+ if isinstance(x, str):
357
+ has_text = True
358
+ text_prompt += x
359
+ else:
360
+ has_image = True
361
+ text_prompt += DEFAULT_IMAGE_TOKEN
362
+ image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values'])
363
+ # pdb.set_trace()
364
+ if len(image_prompt) == 0:
365
+ image_prompt = None
366
+ else:
367
+ image_prompt = torch.cat(image_prompt)
368
+ image_prompt = image_prompt.type(dtype).to(device)
369
+
370
+ if has_image and not has_text:
371
+ prompt = self.encode_images(image_prompt)
372
+ # pdb.set_trace()
373
+ if do_classifier_free_guidance:
374
+ key = "[NULL_IMAGE]"
375
+ if key not in negative_prompt:
376
+ negative_image = torch.zeros_like(image_prompt)
377
+ negative_prompt[key] = self.encode_images(negative_image)
378
+ prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
379
+ else:
380
+ prompt = self.generate_image(text=[text_prompt], image=image_prompt, tokenizer=tokenizer)
381
+ if do_classifier_free_guidance:
382
+ key = ""
383
+ if key not in negative_prompt:
384
+ negative_prompt[key] = self.generate_image(text=[""], tokenizer=tokenizer)
385
+ prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
386
+
387
+ gen_pooling = self.get_gen_pooling()
388
+ n_query = self.get_n_query()
389
+ num_img, _, c = prompt.shape
390
+ if 'pool2d' in gen_pooling and has_text and not 'early' in gen_pooling:
391
+ stride = int(gen_pooling.split('_')[1])
392
+ sqrt_n = int(n_query**0.5)
393
+ prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n)
394
+ prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride)
395
+ prompt = prompt.reshape(num_img, c, -1).permute(0,2,1)
396
+ return prompt
397
+
398
+
399
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
400
+ inputs_embeds=None, **kwargs):
401
+ images = kwargs.pop("images", None)
402
+ image_sizes = kwargs.pop("image_sizes", None)
403
+ inputs = super().prepare_inputs_for_generation(
404
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
405
+ )
406
+ if images is not None:
407
+ inputs['images'] = images
408
+ if image_sizes is not None:
409
+ inputs['image_sizes'] = image_sizes
410
+ return inputs
411
+
412
+ AutoConfig.register("blip3o_llama", blip3oConfig)
413
+ AutoModelForCausalLM.register(blip3oConfig, blip3oLlamaForCausalLM)
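The training branch in forward() above implements a flow-matching (rectified-flow) objective: clean latents are linearly interpolated toward Gaussian noise with a per-sample sigma, and the DiT is trained to regress the velocity noise - latents. The standalone sketch below only illustrates that computation, with random tensors standing in for the VAE latents and the DiT prediction.

# Standalone sketch of the flow-matching loss computed in forward() above.
import torch
import torch.nn.functional as F

latents = torch.randn(2, 4, 64, 64)            # stand-in for encoded image latents
noise = torch.randn_like(latents)
sigmas = torch.rand(2).view(-1, 1, 1, 1)       # one interpolation weight per sample

noisy_latents = (1.0 - sigmas) * latents + sigmas * noise  # same interpolation as forward()
target = noise - latents                                   # velocity target
noise_pred = torch.randn_like(latents)                     # placeholder for the DiT output
img_loss = F.mse_loss(noise_pred.float(), target.float())
print(img_loss.item())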
blip3o/model/language_model/blip3o_qwen.py ADDED
@@ -0,0 +1,420 @@
1
+ from typing import List, Optional, Tuple, Union, Dict
2
+ import torch
3
+ import torch.nn as nn
4
+ from PIL import Image
5
+ import torch.nn.functional as F
6
+
7
+
8
+ import transformers
9
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
10
+
11
+ from transformers.modeling_outputs import CausalLMOutputWithPast
12
+ from transformers.generation.utils import GenerateOutput
13
+
14
+ from blip3o.model.blip3o_arch import blip3oMetaModel, blip3oMetaForCausalLM
15
+
16
+ from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel, Qwen2_5_VLForConditionalGeneration
17
+
18
+ from blip3o.constants import UND_IMAGE_TOKEN_IDX, DEFAULT_IMAGE_TOKEN  # DEFAULT_IMAGE_TOKEN is referenced in prepare_and_encode_inputs below
19
+
20
+
21
+
22
+ from diffusers.utils.torch_utils import randn_tensor
23
+ from diffusers.pipelines.pipeline_utils import numpy_to_pil
24
+ import numpy as np
25
+ from diffusers.models import AutoencoderKL
26
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
27
+
28
+
29
+ class blip3oQwenConfig(Qwen2_5_VLConfig):
30
+ model_type = "blip3o_qwen"
31
+
32
+
33
+ class blip3oQwenModel(blip3oMetaModel, Qwen2_5_VLModel):
34
+ config_class = blip3oQwenConfig
35
+
36
+ def __init__(self, config: Qwen2_5_VLConfig):
37
+ super(blip3oQwenModel, self).__init__(config)
38
+
39
+
40
+ class blip3oQwenForCausalLM(Qwen2_5_VLForConditionalGeneration, blip3oMetaForCausalLM):
41
+ config_class = blip3oQwenConfig
42
+
43
+ def __init__(self, config):
44
+ Qwen2_5_VLForConditionalGeneration.__init__(self, config)
45
+ config.model_type = "blip3o_qwen"
46
+
47
+ self.model = blip3oQwenModel(config)
48
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49
+ # Initialize weights and apply final processing
50
+ self.post_init()
51
+
52
+ def get_model(self):
53
+ return self.model
54
+
55
+
56
+ def forward(
57
+ self,
58
+ input_ids: torch.LongTensor = None,
59
+ attention_mask: Optional[torch.Tensor] = None,
60
+ position_ids: Optional[torch.LongTensor] = None,
61
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
62
+ inputs_embeds: Optional[torch.FloatTensor] = None,
63
+ labels: Optional[torch.LongTensor] = None,
64
+ ids: Optional[list] = None,
65
+ i_s_pos: Optional[list] = None,
66
+ use_cache: Optional[bool] = None,
67
+ output_attentions: Optional[bool] = None,
68
+ output_hidden_states: Optional[bool] = None,
69
+ gen_image: Optional[torch.FloatTensor] = None,
70
+ und_image: Optional[torch.FloatTensor] = None,
71
+ grid_thw: Optional[torch.FloatTensor] = None,
72
+ image_sizes: Optional[List[List[int]]] = None,
73
+ return_dict: Optional[bool] = None,
74
+ cache_position: Optional[torch.LongTensor] = None
75
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
76
+
77
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
78
+ output_hidden_states = (
79
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
80
+ )
81
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
82
+
83
+ if inputs_embeds is None:
84
+ (
85
+ input_ids,
86
+ position_ids,
87
+ attention_mask,
88
+ past_key_values,
89
+ inputs_embeds,
90
+ labels,
91
+ latents
92
+ ) = self.prepare_inputs_labels_for_multimodal(
93
+ input_ids,
94
+ position_ids,
95
+ attention_mask,
96
+ past_key_values,
97
+ labels,
98
+ gen_image,
99
+ und_image,
100
+ grid_thw,
101
+ i_s_pos,
102
+ image_sizes
103
+ )
104
+
105
+ outputs = self.model(
106
+ input_ids=input_ids,
107
+ attention_mask=attention_mask,
108
+ position_ids=position_ids,
109
+ past_key_values=past_key_values,
110
+ inputs_embeds=inputs_embeds,
111
+ use_cache=use_cache,
112
+ output_attentions=output_attentions,
113
+ output_hidden_states=output_hidden_states,
114
+ return_dict=return_dict,
115
+ )
116
+
117
+ hidden_states = outputs[0]
118
+ logits = self.lm_head(hidden_states)
119
+ logits = logits.float()
120
+
121
+ total_loss = None
122
+ if labels is not None:
123
+ # Shift so that tokens < n predict n
124
+ shift_logits = logits[..., :-1, :].contiguous()
125
+ shift_labels = labels[..., 1:].contiguous()
126
+ # Flatten the tokens
127
+ loss_fct = torch.nn.CrossEntropyLoss()
128
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
129
+ shift_labels = shift_labels.view(-1)
130
+ # Enable model parallelism
131
+ shift_labels = shift_labels.to(shift_logits.device)
132
+ loss = loss_fct(shift_logits, shift_labels)
133
+
134
+
135
+ # compute image loss
136
+ # target_img_embeds = torch.clone(inputs_embeds.detach())[:,1:,:] # get target image emb
137
+ img_loss_funct = torch.nn.MSELoss()
138
+ # img_hidden_states = self.get_model().down_projector(hidden_states[:,-self.get_n_query():,:])
139
+ img_hidden_states = []
140
+
141
+ for b in range(hidden_states.shape[0]):
142
+ img_hidden_states.append(hidden_states[b,i_s_pos[b]:i_s_pos[b]+64,:])
143
+ img_hidden_states = torch.stack(img_hidden_states,dim=0)
144
+ img_hidden_states = self.get_model().down_projector(img_hidden_states)
145
+ # img_loss = 0.0
146
+ if latents is None:
147
+ img_loss = img_loss_funct(img_hidden_states, torch.clone(img_hidden_states.detach()))
148
+ else:
149
+ bsz = latents.shape[0]
150
+ # device = latents.device
151
+ dtype = latents.dtype
152
+ noise = torch.randn_like(latents, device=latents.device)
153
+ u = torch.rand(size=(bsz,), device="cpu")
154
+ indices = (u * self.get_model().noise_scheduler.config.num_train_timesteps).long()
155
+ timesteps = self.get_model().noise_scheduler.timesteps[indices].to(device=latents.device)
156
+ sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=dtype)
157
+ noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
158
+ noise_pred = self.get_model().dit(
159
+ x=noisy_latents,
160
+ timestep=timesteps,
161
+ z_latents=self.mask_drop(img_hidden_states),
162
+ )
163
+ target = noise - latents
164
+ img_loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
165
+ print(f"img loss {img_loss}")
166
+ total_loss = img_loss
167
+
168
+ return CausalLMOutputWithPast(
169
+ loss=total_loss,
170
+ logits=logits,
171
+ past_key_values=outputs.past_key_values,
172
+ hidden_states=outputs.hidden_states,
173
+ attentions=outputs.attentions,
174
+ )
175
+
176
+
177
+ @torch.no_grad()
178
+ def generate(
179
+ self,
180
+ inputs: Optional[torch.Tensor] = None,
181
+ images: Optional[torch.Tensor] = None,
182
+ image_sizes: Optional[torch.Tensor] = None,
183
+ **kwargs,
184
+ ) -> Union[GenerateOutput, torch.LongTensor]:
185
+ position_ids = kwargs.pop("position_ids", None)
186
+ attention_mask = kwargs.pop("attention_mask", None)
187
+ if "inputs_embeds" in kwargs:
188
+ raise NotImplementedError("`inputs_embeds` is not supported")
189
+
190
+ if images is not None:
191
+ (
192
+ inputs,
193
+ position_ids,
194
+ attention_mask,
195
+ _,
196
+ inputs_embeds,
197
+ img_indicator,
198
+ _
199
+ ) = self.prepare_inputs_labels_for_understanding(
200
+ inputs,
201
+ position_ids,
202
+ attention_mask,
203
+ None,
204
+ None,
205
+ images,
206
+ image_sizes=image_sizes
207
+ )
208
+ else:
209
+ inputs_embeds = self.get_model().embed_tokens(inputs)
210
+
211
+ return super().generate(
212
+ position_ids=position_ids,
213
+ attention_mask=attention_mask,
214
+ inputs_embeds=inputs_embeds,
215
+ **kwargs
216
+ )
217
+
218
+ @torch.no_grad()
219
+ def generate_image(
220
+ self,
221
+ text: List[str],
222
+ tokenizer: AutoTokenizer,
223
+ pixel_values: Optional[torch.Tensor] = None,
224
+ image_grid_thw: Optional[torch.Tensor] = None,
225
+ max_var: Optional[float] = None,
226
+ # placeholder: str = DEFAULT_IMG_PLACEHOLDER,
227
+ ):
228
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler")
229
+
230
+
231
+ N_QUERY = self.get_n_query()
232
+ inputs = tokenizer(text, padding="longest", return_tensors="pt")
233
+ device = self.get_model().device
234
+ attention_mask = inputs.attention_mask.to(device)
235
+ input_ids = inputs.input_ids.to(device) # B x N
236
+ input_ids = torch.cat([input_ids, torch.tensor([[151665]]).to(device)], dim=1)
237
+ # breakpoint()
238
+
239
+
240
+ text_embeds = self.get_model().embed_tokens(input_ids)
241
+ latent_queries = self.get_model().latent_queries.repeat(text_embeds.shape[0], 1, 1)
242
+
243
+
244
+ if pixel_values is not None:
245
+ und_image_idx = (input_ids == UND_IMAGE_TOKEN_IDX)
246
+ pixel_values = pixel_values.type(self.visual.dtype)
247
+ und_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
248
+ text_embeds[und_image_idx] = und_image_embeds.to(text_embeds.device)[:und_image_idx.sum(), :]
249
+
250
+
251
+ text_embeds = torch.cat([text_embeds, latent_queries], dim=1)
252
+ attention_mask = torch.cat([attention_mask, torch.ones_like(latent_queries[:, :, 0])], dim=1)
253
+
254
+
255
+ outputs = self.model(
256
+ inputs_embeds=text_embeds,
257
+ attention_mask=attention_mask,
258
+ output_hidden_states=True,
259
+ return_dict=True,
260
+ )
261
+ hidden_states = outputs.hidden_states[-1][:,-N_QUERY:,:]
262
+ img_hidden_states = hidden_states
263
+ output_img = self.sample_images(img_hidden_states, scheduler)
264
+ output_img = output_img.view(1, 1792, -1).permute(0,2,1).contiguous()
265
+
266
+ return output_img
267
+
268
+ def sample_images(
269
+ self,
270
+ img_hidden_states,
271
+ scheduler,
272
+ guidance_scale: float = 3.0,
273
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
274
+ num_inference_steps: int = 30,
275
+ num_images_per_prompt: int = 1,
276
+ return_tensor=False,
277
+ **kwargs,
278
+ ):
279
+
280
+ device = img_hidden_states.device
281
+ dtype = img_hidden_states.dtype
282
+
283
+
284
+ img_hidden_states_null = torch.zeros_like(img_hidden_states, device=device, dtype=dtype)
285
+ img_hidden_states_input = torch.cat([img_hidden_states_null, img_hidden_states], 0)
286
+
287
+ batch_size = img_hidden_states.shape[0]
288
+ latent_size = self.get_model().dit.config.input_size
289
+ latent_channels = self.get_model().dit.config.in_channels
290
+
291
+ latents = randn_tensor(
292
+ shape=(batch_size * num_images_per_prompt, latent_channels, latent_size, latent_size),
293
+ generator=generator,
294
+ device=device,
295
+ dtype=dtype,
296
+ )
297
+
298
+ # set step values
299
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
300
+ scheduler.set_timesteps(num_inference_steps, sigmas=sigmas)
301
+
302
+ # Repeat z_latents and conditions for each image per prompt
303
+ img_hidden_states_input = img_hidden_states_input.repeat_interleave(num_images_per_prompt, dim=0)
304
+
305
+ for t in scheduler.timesteps:
306
+ latent_model_input = latents.repeat(2, 1, 1, 1)
307
+ if hasattr(scheduler, "scale_model_input"):
308
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
309
+
310
+ # predict noise model_output
311
+ noise_pred = self.get_model().dit(
312
+ x=latent_model_input,
313
+ timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latent_model_input.device, torch.long),
314
+ z_latents=img_hidden_states_input,
315
+ )
316
+
317
+ # perform guidance
318
+ noise_pred_uncond, noise_pred = noise_pred.chunk(2)
319
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
320
+
321
+ # compute previous image: x_t -> x_t-1
322
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
323
+
324
+ # samples = self.decode_latents(latents, return_tensor=return_tensor)
325
+ # breakpoint()
326
+ return latents
327
+
328
+ def decode_latents(self, latents, normalize=True, return_tensor=False):
329
+ if isinstance(self.get_model().vae, AutoencoderKL):
330
+ latents = latents / self.get_model().vae.config.scaling_factor
331
+ if self.get_model().vae.config.shift_factor is not None:
332
+ latents = latents + self.get_model().vae.config.shift_factor
333
+ latents = latents.to(dtype=torch.float32)
334
+ samples = self.get_model().vae.decode(latents).sample
335
+ else:
336
+ samples = self.get_model().vae.decode(latents)
337
+ if normalize:
338
+ samples = (samples / 2 + 0.5).clamp(0, 1)
339
+ else:
340
+ samples = samples.clamp(-1, 1)
341
+ if return_tensor:
342
+ return samples
343
+ samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
344
+ samples = numpy_to_pil(samples)
345
+ return samples
346
+
347
+ def prepare_and_encode_inputs(
348
+ self,
349
+ inputs: List[str | Image.Image],
350
+ tokenizer: AutoTokenizer,
351
+ do_classifier_free_guidance: bool = False,
352
+ ):
353
+ # pdb.set_trace()
354
+ device = self.get_model().device
355
+ dtype = self.get_model().dtype
356
+
357
+ has_image, has_text = False, False
358
+ text_prompt, image_prompt = "", []
359
+ img_processor = self.get_vision_tower().image_processor
360
+ negative_prompt = {}
361
+
362
+ for x in inputs:
363
+ if isinstance(x, str):
364
+ has_text = True
365
+ text_prompt += x
366
+ else:
367
+ has_image = True
368
+ text_prompt += DEFAULT_IMAGE_TOKEN
369
+ image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values'])
370
+ # pdb.set_trace()
371
+ if len(image_prompt) == 0:
372
+ image_prompt = None
373
+ else:
374
+ image_prompt = torch.cat(image_prompt)
375
+ image_prompt = image_prompt.type(dtype).to(device)
376
+
377
+ if has_image and not has_text:
378
+ prompt = self.encode_images(image_prompt)
379
+ # pdb.set_trace()
380
+ if do_classifier_free_guidance:
381
+ key = "[NULL_IMAGE]"
382
+ if key not in negative_prompt:
383
+ negative_image = torch.zeros_like(image_prompt)
384
+ negative_prompt[key] = self.encode_images(negative_image)
385
+ prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
386
+ else:
387
+ prompt = self.generate_image(text=[text_prompt], image=image_prompt, tokenizer=tokenizer)
388
+ if do_classifier_free_guidance:
389
+ key = ""
390
+ if key not in negative_prompt:
391
+ negative_prompt[key] = self.generate_image(text=[""], tokenizer=tokenizer)
392
+ prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
393
+
394
+ gen_pooling = self.get_gen_pooling()
395
+ n_query = self.get_n_query()
396
+ num_img, _, c = prompt.shape
397
+ if 'pool2d' in gen_pooling and has_text and not 'early' in gen_pooling:
398
+ stride = int(gen_pooling.split('_')[1])
399
+ sqrt_n = int(n_query**0.5)
400
+ prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n)
401
+ prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride)
402
+ prompt = prompt.reshape(num_img, c, -1).permute(0,2,1)
403
+ return prompt
404
+
405
+
406
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
407
+ inputs_embeds=None, **kwargs):
408
+ images = kwargs.pop("images", None)
409
+ image_sizes = kwargs.pop("image_sizes", None)
410
+ inputs = super().prepare_inputs_for_generation(
411
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
412
+ )
413
+ if images is not None:
414
+ inputs['images'] = images
415
+ if image_sizes is not None:
416
+ inputs['image_sizes'] = image_sizes
417
+ return inputs
418
+
419
+ AutoConfig.register("blip3o_qwen", blip3oQwenConfig)
420
+ AutoModelForCausalLM.register(blip3oQwenConfig, blip3oQwenForCausalLM)
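A hedged sketch of driving text-to-image generation with the class above, assuming `model` and `tokenizer` were loaded beforehand (for example with the loader earlier in this commit) and the generation components (latent queries, DiT) are present in the checkpoint. Note that `sample_images` applies classifier-free guidance with guidance_scale=3.0 by default, i.e. pred = uncond + 3.0 * (cond - uncond).

# Sketch only: `model` is an already-loaded blip3oQwenForCausalLM and `tokenizer` its tokenizer;
# the prompt is arbitrary.
import torch

with torch.no_grad():
    latents = model.generate_image(
        text=["A photo of a red bicycle leaning against a brick wall"],
        tokenizer=tokenizer,
    )
# generate_image returns the sampled latents reshaped to (1, N, 1792); turning them into
# pixels goes through the diffusion decoder / VAE in the full pipeline.
print(latents.shape)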
blip3o/model/lumina_nextdit2d.py ADDED
@@ -0,0 +1,365 @@
1
+ # Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
20
+ from diffusers.models.attention import LuminaFeedForward
21
+ from diffusers.models.attention_processor import Attention, LuminaAttnProcessor2_0
22
+ from diffusers.models.embeddings import LuminaCombinedTimestepCaptionEmbedding, LuminaPatchEmbed, PixArtAlphaTextProjection
23
+
24
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ from diffusers.models.normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
27
+ from diffusers.utils import is_torch_version, logging
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ class LuminaNextDiTBlock(nn.Module):
33
+ """
34
+ A LuminaNextDiTBlock for LuminaNextDiT2DModel.
35
+
36
+ Parameters:
37
+ dim (`int`): Embedding dimension of the input features.
38
+ num_attention_heads (`int`): Number of attention heads.
39
+ num_kv_heads (`int`):
40
+ Number of attention heads in key and value features (if using GQA), or set to None for the same as query.
41
+ multiple_of (`int`): The number of multiple of ffn layer.
42
+ ffn_dim_multiplier (`float`): The multipier factor of ffn layer dimension.
43
+ norm_eps (`float`): The eps for norm layer.
44
+ qk_norm (`bool`): normalization for query and key.
45
+ cross_attention_dim (`int`): Cross attention embedding dimension of the input text prompt hidden_states.
46
+ norm_elementwise_affine (`bool`, *optional*, defaults to True),
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ dim: int,
52
+ num_attention_heads: int,
53
+ num_kv_heads: int,
54
+ multiple_of: int,
55
+ ffn_dim_multiplier: float,
56
+ norm_eps: float,
57
+ qk_norm: bool,
58
+ cross_attention_dim: int,
59
+ norm_elementwise_affine: bool = True,
60
+ ) -> None:
61
+ super().__init__()
62
+ self.head_dim = dim // num_attention_heads
63
+
64
+ self.gate = nn.Parameter(torch.zeros([num_attention_heads]))
65
+
66
+ # Self-attention
67
+ self.attn1 = Attention(
68
+ query_dim=dim,
69
+ cross_attention_dim=None,
70
+ dim_head=dim // num_attention_heads,
71
+ qk_norm="layer_norm_across_heads" if qk_norm else None,
72
+ heads=num_attention_heads,
73
+ kv_heads=num_kv_heads,
74
+ eps=1e-5,
75
+ bias=False,
76
+ out_bias=False,
77
+ processor=LuminaAttnProcessor2_0(),
78
+ )
79
+ self.attn1.to_out = nn.Identity()
80
+
81
+ # Cross-attention
82
+ self.attn2 = Attention(
83
+ query_dim=dim,
84
+ cross_attention_dim=cross_attention_dim,
85
+ dim_head=dim // num_attention_heads,
86
+ qk_norm="layer_norm_across_heads" if qk_norm else None,
87
+ heads=num_attention_heads,
88
+ kv_heads=num_kv_heads,
89
+ eps=1e-5,
90
+ bias=False,
91
+ out_bias=False,
92
+ processor=LuminaAttnProcessor2_0(),
93
+ )
94
+
95
+ self.feed_forward = LuminaFeedForward(
96
+ dim=dim,
97
+ inner_dim=4 * dim,
98
+ multiple_of=multiple_of,
99
+ ffn_dim_multiplier=ffn_dim_multiplier,
100
+ )
101
+
102
+ self.norm1 = LuminaRMSNormZero(
103
+ embedding_dim=dim,
104
+ norm_eps=norm_eps,
105
+ norm_elementwise_affine=norm_elementwise_affine,
106
+ )
107
+ self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
108
+
109
+ self.norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
110
+ self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
111
+
112
+ self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
113
+
114
+ def forward(
115
+ self,
116
+ hidden_states: torch.Tensor,
117
+ attention_mask: torch.Tensor,
118
+ image_rotary_emb: torch.Tensor,
119
+ encoder_hidden_states: torch.Tensor,
120
+ encoder_mask: torch.Tensor,
121
+ temb: torch.Tensor,
122
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
123
+ ):
124
+ """
125
+ Perform a forward pass through the LuminaNextDiTBlock.
126
+
127
+ Parameters:
128
+ hidden_states (`torch.Tensor`): Input hidden states for the LuminaNextDiTBlock.
129
+ attention_mask (`torch.Tensor`): Attention mask corresponding to `hidden_states`.
130
+ image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies.
131
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the text prompt, as produced by the Gemma encoder.
132
+ encoder_mask (`torch.Tensor`): Attention mask for the text prompt hidden states.
133
+ temb (`torch.Tensor`): Timestep embedding combined with the text prompt embedding.
134
+ cross_attention_kwargs (`Dict[str, Any]`): kwargs for cross attention.
135
+ """
136
+ residual = hidden_states
137
+
138
+ # Self-attention
139
+ norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
140
+ self_attn_output = self.attn1(
141
+ hidden_states=norm_hidden_states,
142
+ encoder_hidden_states=norm_hidden_states,
143
+ attention_mask=attention_mask,
144
+ query_rotary_emb=image_rotary_emb,
145
+ key_rotary_emb=image_rotary_emb,
146
+ **cross_attention_kwargs,
147
+ )
148
+
149
+ # Cross-attention
150
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
151
+ cross_attn_output = self.attn2(
152
+ hidden_states=norm_hidden_states,
153
+ encoder_hidden_states=norm_encoder_hidden_states,
154
+ attention_mask=encoder_mask,
155
+ query_rotary_emb=image_rotary_emb,
156
+ key_rotary_emb=None,
157
+ **cross_attention_kwargs,
158
+ )
159
+ cross_attn_output = cross_attn_output * self.gate.tanh().view(1, 1, -1, 1)
160
+ mixed_attn_output = self_attn_output + cross_attn_output
161
+ mixed_attn_output = mixed_attn_output.flatten(-2)
162
+ # linear proj
163
+ hidden_states = self.attn2.to_out[0](mixed_attn_output)
164
+
165
+ hidden_states = residual + gate_msa.unsqueeze(1).tanh() * self.norm2(hidden_states)
166
+
167
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
168
+
169
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
170
+
171
+ return hidden_states
172
+
173
+
174
+ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
175
+ """
176
+ LuminaNextDiT: Diffusion model with a Transformer backbone.
177
+
178
+ Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
179
+
180
+ Parameters:
181
+ sample_size (`int`): The width of the latent images. This is fixed during training since
182
+ it is used to learn a number of position embeddings.
183
+ patch_size (`int`, *optional*, defaults to 2):
184
+ The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
185
+ in_channels (`int`, *optional*, defaults to 4):
186
+ The number of input channels for the model. Typically, this matches the number of channels in the input
187
+ images.
188
+ hidden_size (`int`, *optional*, defaults to 4096):
189
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
190
+ hidden representations.
191
+ num_layers (`int`, *optional*, default to 32):
192
+ The number of layers in the model. This defines the depth of the neural network.
193
+ num_attention_heads (`int`, *optional*, defaults to 32):
194
+ The number of attention heads in each attention layer. This parameter specifies how many separate attention
195
+ mechanisms are used.
196
+ num_kv_heads (`int`, *optional*, defaults to 8):
197
+ The number of key-value heads in the attention mechanism, if different from the number of attention heads.
198
+ If None, it defaults to num_attention_heads.
199
+ multiple_of (`int`, *optional*, defaults to 256):
200
+ A factor that the hidden size should be a multiple of. This can help optimize certain hardware
201
+ configurations.
202
+ ffn_dim_multiplier (`float`, *optional*):
203
+ A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
204
+ the model configuration.
205
+ norm_eps (`float`, *optional*, defaults to 1e-5):
206
+ A small value added to the denominator for numerical stability in normalization layers.
207
+ learn_sigma (`bool`, *optional*, defaults to True):
208
+ Whether the model should learn the sigma parameter, which might be related to uncertainty or variance in
209
+ predictions.
210
+ qk_norm (`bool`, *optional*, defaults to True):
211
+ Indicates if the queries and keys in the attention mechanism should be normalized.
212
+ cross_attention_dim (`int`, *optional*, defaults to 2048):
213
+ The dimensionality of the text embeddings. This parameter defines the size of the text representations used
214
+ in the model.
215
+ scaling_factor (`float`, *optional*, defaults to 1.0):
216
+ A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
217
+ overall scale of the model's operations.
218
+ """
219
+
220
+ _supports_gradient_checkpointing = True
221
+ _no_split_modules = ["LuminaNextDiTBlock"]
222
+
223
+ @register_to_config
224
+ def __init__(
225
+ self,
226
+ sample_size: int = 128,
227
+ patch_size: Optional[int] = 2,
228
+ in_channels: Optional[int] = 4,
229
+ hidden_size: Optional[int] = 2304,
230
+ num_layers: Optional[int] = 32, # 32
231
+ num_attention_heads: Optional[int] = 32, # 32
232
+ num_kv_heads: Optional[int] = None,
233
+ multiple_of: Optional[int] = 256,
234
+ ffn_dim_multiplier: Optional[float] = None,
235
+ norm_eps: Optional[float] = 1e-5,
236
+ learn_sigma: Optional[bool] = True,
237
+ qk_norm: Optional[bool] = True,
238
+ cross_attention_dim: Optional[int] = 2048,
239
+ scaling_factor: Optional[float] = 1.0,
240
+ ) -> None:
241
+ super().__init__()
242
+ self.sample_size = sample_size
243
+ self.patch_size = patch_size
244
+ self.in_channels = in_channels
245
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
246
+ self.hidden_size = hidden_size
247
+ self.num_attention_heads = num_attention_heads
248
+ self.head_dim = hidden_size // num_attention_heads
249
+ self.scaling_factor = scaling_factor
250
+ self.gradient_checkpointing = False
251
+
252
+ self.caption_projection = PixArtAlphaTextProjection(in_features=cross_attention_dim, hidden_size=hidden_size)
253
+ self.patch_embedder = LuminaPatchEmbed(patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, bias=True)
254
+
255
+ self.time_caption_embed = LuminaCombinedTimestepCaptionEmbedding(hidden_size=min(hidden_size, 1024), cross_attention_dim=hidden_size)
256
+
257
+ self.layers = nn.ModuleList(
258
+ [
259
+ LuminaNextDiTBlock(
260
+ hidden_size,
261
+ num_attention_heads,
262
+ num_kv_heads,
263
+ multiple_of,
264
+ ffn_dim_multiplier,
265
+ norm_eps,
266
+ qk_norm,
267
+ hidden_size,
268
+ )
269
+ for _ in range(num_layers)
270
+ ]
271
+ )
272
+ self.norm_out = LuminaLayerNormContinuous(
273
+ embedding_dim=hidden_size,
274
+ conditioning_embedding_dim=min(hidden_size, 1024),
275
+ elementwise_affine=False,
276
+ eps=1e-6,
277
+ bias=True,
278
+ out_dim=patch_size * patch_size * self.out_channels,
279
+ )
280
+ # self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels)
281
+
282
+ assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
283
+
284
+ def _set_gradient_checkpointing(self, module, value=False):
285
+ if hasattr(module, "gradient_checkpointing"):
286
+ module.gradient_checkpointing = value
287
+
288
+ def forward(
289
+ self,
290
+ hidden_states: torch.Tensor,
291
+ timestep: torch.Tensor,
292
+ encoder_hidden_states: torch.Tensor,
293
+ encoder_mask: torch.Tensor,
294
+ image_rotary_emb: torch.Tensor,
295
+ cross_attention_kwargs: Dict[str, Any] = None,
296
+ return_dict=True,
297
+ ) -> torch.Tensor:
298
+ """
299
+ Forward pass of LuminaNextDiT.
300
+
301
+ Parameters:
302
+ hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
303
+ timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,).
304
+ encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D).
305
+ encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L).
306
+ """
307
+ hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)
308
+ image_rotary_emb = image_rotary_emb.to(hidden_states.device)
309
+ # breakpoint()
310
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
311
+ temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask)
312
+
313
+ encoder_mask = encoder_mask.bool()
314
+
315
+ for layer in self.layers:
316
+ if self.training and self.gradient_checkpointing:
317
+
318
+ def create_custom_forward(module, return_dict=None):
319
+ def custom_forward(*inputs):
320
+ if return_dict is not None:
321
+ return module(*inputs, return_dict=return_dict)
322
+ else:
323
+ return module(*inputs)
324
+
325
+ return custom_forward
326
+
327
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
328
+ hidden_states = torch.utils.checkpoint.checkpoint(
329
+ create_custom_forward(layer),
330
+ hidden_states,
331
+ mask,
332
+ image_rotary_emb,
333
+ encoder_hidden_states,
334
+ encoder_mask,
335
+ temb,
336
+ cross_attention_kwargs,
337
+ **ckpt_kwargs,
338
+ )
339
+ else:
340
+ hidden_states = layer(
341
+ hidden_states,
342
+ mask,
343
+ image_rotary_emb,
344
+ encoder_hidden_states,
345
+ encoder_mask,
346
+ temb=temb,
347
+ cross_attention_kwargs=cross_attention_kwargs,
348
+ )
349
+
350
+ hidden_states = self.norm_out(hidden_states, temb)
351
+
352
+ # unpatchify
353
+ height_tokens = width_tokens = self.patch_size
354
+ height, width = img_size[0]
355
+ batch_size = hidden_states.size(0)
356
+ sequence_length = (height // height_tokens) * (width // width_tokens)
357
+ hidden_states = hidden_states[:, :sequence_length].view(
358
+ batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels
359
+ )
360
+ output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
361
+
362
+ if not return_dict:
363
+ return (output,)
364
+
365
+ return Transformer2DModelOutput(sample=output)
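To make the configuration parameters above concrete, here is a hedged sketch that instantiates a deliberately tiny LuminaNextDiT2DModel and counts its parameters; the values are illustrative only and much smaller than any real checkpoint.

# Sketch only: tiny, illustrative configuration.
from blip3o.model.lumina_nextdit2d import LuminaNextDiT2DModel

tiny_dit = LuminaNextDiT2DModel(
    sample_size=32,           # latent resolution
    patch_size=2,
    in_channels=4,
    hidden_size=256,
    num_layers=2,
    num_attention_heads=4,    # head_dim = 64, divisible by 4 as required for 2D RoPE
    cross_attention_dim=256,  # dimension of the conditioning hidden states
)
n_params = sum(p.numel() for p in tiny_dit.parameters())
print(f"tiny LuminaNextDiT2DModel parameters: {n_params / 1e6:.2f}M")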
blip3o/model/make_delta.py ADDED
@@ -0,0 +1,48 @@
1
+ import argparse
2
+
3
+ import torch
4
+ from tqdm import tqdm
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ from blip3o.model.utils import auto_upgrade
7
+
8
+
9
+ def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
10
+ print("Loading base model")
11
+ base = AutoModelForCausalLM.from_pretrained(
12
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
13
+
14
+ print("Loading target model")
15
+ auto_upgrade(target_model_path)
16
+ target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Calculating delta")
19
+ for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
20
+ if name not in base.state_dict():
21
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
22
+ continue
23
+ if param.data.shape == base.state_dict()[name].shape:
24
+ param.data -= base.state_dict()[name]
25
+ else:
26
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
27
+ bparam = base.state_dict()[name]
28
+ param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
29
+
30
+ print("Saving delta")
31
+ if hub_repo_id:
32
+ kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
33
+ else:
34
+ kwargs = {}
35
+ target.save_pretrained(delta_path, **kwargs)
36
+ target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
37
+ target_tokenizer.save_pretrained(delta_path, **kwargs)
38
+
39
+
40
+ if __name__ == "__main__":
41
+ parser = argparse.ArgumentParser()
42
+ parser.add_argument("--base-model-path", type=str, required=True)
43
+ parser.add_argument("--target-model-path", type=str, required=True)
44
+ parser.add_argument("--delta-path", type=str, required=True)
45
+ parser.add_argument("--hub-repo-id", type=str, default=None)
46
+ args = parser.parse_args()
47
+
48
+ make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
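A hedged sketch of calling make_delta programmatically; all paths are placeholders for local checkpoint directories, and hub_repo_id stays None unless the delta should be pushed to the Hub.

# Sketch only: all paths are placeholders.
from blip3o.model.make_delta import make_delta

make_delta(
    base_model_path="/models/base-llm",             # base weights
    target_model_path="/models/blip3o-finetuned",   # fine-tuned weights
    delta_path="/models/blip3o-delta",              # where the weight difference is written
    hub_repo_id=None,                               # or a repo id to push_to_hub
)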
blip3o/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,63 @@
1
+ import os
2
+ from .clip_encoder import CLIPVisionTower
3
+ from .imagebind import ImageBindWrapper
4
+ from .open_clip_encoder import OpenCLIPVisionTower
5
+ from .siglip_encoder import SigLipVisionTower
6
+ from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
7
+
8
+ from .eva_clip.eva_clip_encoder import EvaClipVisionTower
9
+ from .dev_eva_clip.eva_vit import EvaViTWrapper
10
+
11
+ from blip3o.model.nextdit_crossattn import NextDiTCrossAttnConfig, NextDiTCrossAttn
12
+ from diffusers.models import AutoencoderKL
13
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
14
+
15
+
16
+ def build_vision_tower(vision_tower_cfg, **kwargs):
17
+ vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
18
+ is_absolute_path_exists = os.path.exists(vision_tower)
19
+ use_s2 = getattr(vision_tower_cfg, 's2', False)
20
+ if "siglip" in vision_tower:
21
+ return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
22
+ if "eva" in vision_tower:
23
+ return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
24
+ if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
25
+ if use_s2:
26
+ return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
27
+ else:
28
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
29
+
30
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
31
+
32
+
33
+
34
+
35
+ def build_gen_vision_tower(vision_tower_cfg, **kwargs):
36
+ vision_tower = getattr(vision_tower_cfg, 'gen_vision_tower')
37
+ is_absolute_path_exists = os.path.exists(vision_tower)
38
+ use_s2 = getattr(vision_tower_cfg, 's2', False)
39
+ if "siglip" in vision_tower:
40
+ return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
41
+ if "eva" in vision_tower:
42
+ return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
43
+ if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
44
+ if use_s2:
45
+ return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
46
+ else:
47
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
48
+
49
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
50
+
51
+
52
+
53
+ def build_dit(vision_tower_cfg, **kwargs):
54
+ vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
55
+ # vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
56
+ dit = NextDiTCrossAttn(NextDiTCrossAttnConfig())
57
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler")
58
+ # scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler")
59
+ vae.eval()
60
+ vae.requires_grad_(False)
61
+ return dit, vae, noise_scheduler
62
+
63
+
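A hedged sketch of the dispatch performed by build_vision_tower above, using a minimal config namespace; the CLIP checkpoint name is a commonly used one and only an assumption here, not necessarily what this repo trains with.

# Sketch only: minimal config; the checkpoint id is an assumption.
from types import SimpleNamespace
from blip3o.model.multimodal_encoder.builder import build_vision_tower

cfg = SimpleNamespace(
    mm_vision_tower="openai/clip-vit-large-patch14-336",
    mm_vision_select_layer=-2,
    mm_vision_select_feature="patch",
    s2=False,
)
tower = build_vision_tower(cfg)   # names starting with "openai" dispatch to CLIPVisionTower
print(type(tower).__name__, tower.hidden_size)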
blip3o/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,172 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
4
+
5
+ try:
6
+ from s2wrapper import forward as multiscale_forward
7
+ except ImportError:  # s2wrapper is optional; it is only needed for CLIPVisionTowerS2
8
+ pass
9
+
10
+
11
+ class CLIPVisionTower(nn.Module):
12
+ def __init__(self, vision_tower, args, delay_load=False):
13
+ super().__init__()
14
+
15
+ self.is_loaded = False
16
+
17
+ self.vision_tower_name = vision_tower
18
+ self.select_layer = args.mm_vision_select_layer
19
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
20
+
21
+ if not delay_load:
22
+ print(f"Loading vision tower: {vision_tower}")
23
+ self.load_model()
24
+ elif getattr(args, "unfreeze_mm_vision_tower", False):
25
+ # TODO: better detector is needed.
26
+ print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
27
+ self.load_model()
28
+ elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
29
+ print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
30
+ self.load_model()
31
+ else:
32
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
33
+
34
+ def load_model(self, device_map=None):
35
+ if self.is_loaded:
36
+ print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
37
+ return
38
+
39
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
40
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
41
+ self.vision_tower.requires_grad_(False)
42
+
43
+ self.is_loaded = True
44
+
45
+ def feature_select(self, image_forward_outs):
46
+ select_feature_type = self.select_feature
47
+
48
+ if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
49
+ select_every_k_layer = len(image_forward_outs.hidden_states) // 4
50
+ image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1)
51
+ select_feature_type = select_feature_type.replace("slicefour_", "")
52
+ elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
53
+ select_layers = [-2, -5, -8, -11, 6]
54
+ image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1)
55
+ select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
56
+ else:
57
+ image_features = image_forward_outs.hidden_states[self.select_layer]
58
+
59
+ if select_feature_type == "patch":
60
+ image_features = image_features[:, 1:]
61
+ elif select_feature_type == "cls_patch":
62
+ image_features = image_features
63
+ else:
64
+ raise ValueError(f"Unexpected select feature: {select_feature_type}")
65
+ return image_features
66
+
67
+ def forward(self, images):
68
+ if type(images) is list:
69
+ image_features = []
70
+ for image in images:
71
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
72
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
73
+ image_features.append(image_feature)
74
+ else:
75
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
76
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
77
+
78
+ return image_features
79
+
80
+ @property
81
+ def dummy_feature(self):
82
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
83
+
84
+ @property
85
+ def dtype(self):
86
+ return self.vision_tower.dtype
87
+
88
+ @property
89
+ def device(self):
90
+ return self.vision_tower.device
91
+
92
+ @property
93
+ def config(self):
94
+ if self.is_loaded:
95
+ return self.vision_tower.config
96
+ else:
97
+ return self.cfg_only
98
+
99
+ @property
100
+ def hidden_size(self):
101
+ _hidden_size = self.config.hidden_size
102
+ if "slicefour" in self.select_feature:
103
+ _hidden_size *= 4
104
+ if "slice_m25811_f6" in self.select_feature:
105
+ _hidden_size *= 5
106
+ return _hidden_size
107
+
108
+ @property
109
+ def num_patches_per_side(self):
110
+ return self.config.image_size // self.config.patch_size
111
+
112
+ @property
113
+ def num_patches(self):
114
+ _num_patches = (self.config.image_size // self.config.patch_size) ** 2
115
+ if "cls_patch" in self.select_feature:
116
+ _num_patches += 1
117
+ return _num_patches
118
+
119
+ @property
120
+ def image_size(self):
121
+ return self.config.image_size
122
+
123
+
124
+ class CLIPVisionTowerS2(CLIPVisionTower):
125
+ def __init__(self, vision_tower, args, delay_load=False):
126
+
127
+ self.s2_scales = getattr(args, "s2_scales", "336,672,1008")
128
+ self.s2_scales = list(map(int, self.s2_scales.split(",")))
129
+ self.s2_scales.sort()
130
+ self.s2_split_size = self.s2_scales[0]
131
+ self.s2_image_size = self.s2_scales[-1]
132
+
133
+ super().__init__(vision_tower, args, delay_load)
134
+
135
+ # change resize/crop size in preprocessing to the largest image size in s2_scale
136
+ if not delay_load or getattr(args, "unfreeze_mm_vision_tower", False):
137
+ self.image_processor.size["shortest_edge"] = self.s2_image_size
138
+ self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size
139
+
140
+ def load_model(self, device_map=None):
141
+ if self.is_loaded:
142
+ print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
143
+ return
144
+
145
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
146
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
147
+ self.vision_tower.requires_grad_(False)
148
+
149
+ self.image_processor.size["shortest_edge"] = self.s2_image_size
150
+ self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size
151
+
152
+ self.is_loaded = True
153
+
154
+ def forward_feature(self, images):
155
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
156
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
157
+ return image_features
158
+
159
+ def forward(self, images):
160
+ if type(images) is list:
161
+ image_features = []
162
+ for image in images:
163
+ image_feature = multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True)
164
+ image_features.append(image_feature)
165
+ else:
166
+ image_features = multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True)
167
+
168
+ return image_features
169
+
170
+ @property
171
+ def hidden_size(self):
172
+ return self.config.hidden_size * len(self.s2_scales)
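Usage sketch (illustrative, not part of the uploaded files): the args namespace below mirrors the attributes read in CLIPVisionTower.__init__, while the checkpoint id, input resolution, and expected output shape are assumptions for a standard CLIP ViT-L/14 tower at 336px.

from types import SimpleNamespace
from PIL import Image

args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature="patch")
tower = CLIPVisionTower("openai/clip-vit-large-patch14-336", args)  # delay_load=False, so weights load immediately

pixels = tower.image_processor(images=Image.new("RGB", (336, 336)), return_tensors="pt")["pixel_values"]
features = tower(pixels.to(device=tower.device, dtype=tower.dtype))
# select_feature="patch" drops the CLS token, so features has shape (1, num_patches, hidden_size),
# e.g. (1, 576, 1024) for ViT-L/14 at 336px with select_layer=-2.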
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
2
+ from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
3
+ from .factory import list_models, add_model_config, get_model_config, load_checkpoint
4
+ from .loss import ClipLoss
5
+ from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
6
+ from .openai import load_openai_model, list_openai_models
7
+ from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
8
+ from .tokenizer import SimpleTokenizer, tokenize
9
+ from .transform import image_transform
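These re-exports give the package an open_clip-style surface. A hedged sketch of the intended call pattern, assuming the package is importable from this repository layout and that the "EVA02-CLIP-L-14-336" config JSON exists under this package's model_configs/ directory:

from blip3o.model.multimodal_encoder.dev_eva_clip.eva_clip import create_model_and_transforms, get_tokenizer

# pretrained=None builds randomly initialized weights, so nothing is downloaded.
model, preprocess_train, preprocess_val = create_model_and_transforms(
    "EVA02-CLIP-L-14-336", pretrained=None, precision="fp32", device="cpu"
)
tokenizer = get_tokenizer("EVA02-CLIP-L-14-336")
text_tokens = tokenizer(["a photo of a cat"])  # LongTensor of shape (1, context_length)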
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py ADDED
@@ -0,0 +1,2 @@
1
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
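These are the standard OpenAI CLIP normalization statistics consumed by transform.py via image_transform. As a quick illustration (assuming torchvision is installed), an equivalent eval-time pipeline looks like:

from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

preprocess = Compose([
    Resize(224),
    CenterCrop(224),
    ToTensor(),  # scales pixel values to [0, 1] before normalization
    Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
])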
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/eva_vit_model.py ADDED
@@ -0,0 +1,571 @@
1
+ # --------------------------------------------------------
2
+ # Adapted from https://github.com/microsoft/unilm/tree/master/beit
3
+ # --------------------------------------------------------
4
+ import math
5
+ import os
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ try:
11
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
12
+ except ImportError:
13
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
14
+
15
+ from .transformer import PatchDropout
16
+ from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
17
+
18
+ if os.getenv("ENV_TYPE") == "deepspeed":
19
+ try:
20
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
21
+ except ImportError:
22
+ from torch.utils.checkpoint import checkpoint
23
+ else:
24
+ from torch.utils.checkpoint import checkpoint
25
+
26
+ try:
27
+ import xformers.ops as xops
28
+ except ImportError:
29
+ xops = None
30
+ # print("Please 'pip install xformers'")
31
+
32
+
33
+ class DropPath(nn.Module):
34
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
35
+
36
+ def __init__(self, drop_prob=None):
37
+ super(DropPath, self).__init__()
38
+ self.drop_prob = drop_prob
39
+
40
+ def forward(self, x):
41
+ return drop_path(x, self.drop_prob, self.training)
42
+
43
+ def extra_repr(self) -> str:
44
+ return "p={}".format(self.drop_prob)
45
+
46
+
47
+ class Mlp(nn.Module):
48
+ def __init__(
49
+ self,
50
+ in_features,
51
+ hidden_features=None,
52
+ out_features=None,
53
+ act_layer=nn.GELU,
54
+ norm_layer=nn.LayerNorm,
55
+ drop=0.0,
56
+ subln=False,
57
+ ):
58
+ super().__init__()
59
+ out_features = out_features or in_features
60
+ hidden_features = hidden_features or in_features
61
+ self.fc1 = nn.Linear(in_features, hidden_features)
62
+ self.act = act_layer()
63
+
64
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
65
+
66
+ self.fc2 = nn.Linear(hidden_features, out_features)
67
+ self.drop = nn.Dropout(drop)
68
+
69
+ def forward(self, x):
70
+ x = self.fc1(x)
71
+ x = self.act(x)
72
+ # x = self.drop(x)
73
+ # dropout commented out to match the original BERT implementation
74
+ x = self.ffn_ln(x)
75
+
76
+ x = self.fc2(x)
77
+ x = self.drop(x)
78
+ return x
79
+
80
+
81
+ class SwiGLU(nn.Module):
82
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.0, norm_layer=nn.LayerNorm, subln=False):
83
+ super().__init__()
84
+ out_features = out_features or in_features
85
+ hidden_features = hidden_features or in_features
86
+
87
+ self.w1 = nn.Linear(in_features, hidden_features)
88
+ self.w2 = nn.Linear(in_features, hidden_features)
89
+
90
+ self.act = act_layer()
91
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
92
+ self.w3 = nn.Linear(hidden_features, out_features)
93
+
94
+ self.drop = nn.Dropout(drop)
95
+
96
+ def forward(self, x):
97
+ x1 = self.w1(x)
98
+ x2 = self.w2(x)
99
+ hidden = self.act(x1) * x2
100
+ x = self.ffn_ln(hidden)
101
+ x = self.w3(x)
102
+ x = self.drop(x)
103
+ return x
104
+
105
+
106
+ class Attention(nn.Module):
107
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
108
+ super().__init__()
109
+ self.num_heads = num_heads
110
+ head_dim = dim // num_heads
111
+ if attn_head_dim is not None:
112
+ head_dim = attn_head_dim
113
+ all_head_dim = head_dim * self.num_heads
114
+ self.scale = qk_scale or head_dim**-0.5
115
+
116
+ self.subln = subln
117
+ if self.subln:
118
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
119
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
120
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
121
+ else:
122
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
123
+
124
+ if qkv_bias:
125
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
126
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
127
+ else:
128
+ self.q_bias = None
129
+ self.v_bias = None
130
+
131
+ if window_size:
132
+ self.window_size = window_size
133
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
134
+ self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
135
+ # cls to token & token 2 cls & cls to cls
136
+
137
+ # get pair-wise relative position index for each token inside the window
138
+ coords_h = torch.arange(window_size[0])
139
+ coords_w = torch.arange(window_size[1])
140
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
141
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
142
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
143
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
144
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
145
+ relative_coords[:, :, 1] += window_size[1] - 1
146
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
147
+ relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
148
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
149
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
150
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
151
+ relative_position_index[0, 0] = self.num_relative_distance - 1
152
+
153
+ self.register_buffer("relative_position_index", relative_position_index)
154
+ else:
155
+ self.window_size = None
156
+ self.relative_position_bias_table = None
157
+ self.relative_position_index = None
158
+
159
+ self.attn_drop = nn.Dropout(attn_drop)
160
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
161
+ # self.proj = nn.Linear(all_head_dim, all_head_dim)
162
+ self.proj = nn.Linear(all_head_dim, dim)
163
+ self.proj_drop = nn.Dropout(proj_drop)
164
+ self.xattn = xattn
165
+ self.xattn_drop = attn_drop
166
+
167
+ self.rope = rope
168
+
169
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
170
+ B, N, C = x.shape
171
+ if self.subln:
172
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
173
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
174
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
175
+
176
+ q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
177
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
178
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
179
+ else:
180
+
181
+ qkv_bias = None
182
+ if self.q_bias is not None:
183
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
184
+
185
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
186
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
187
+ q, k, v = qkv[0], qkv[1], qkv[2]
188
+
189
+ if self.rope:
190
+ # slightly faster impl
191
+ q_t = q[:, :, 1:, :]
192
+ ro_q_t = self.rope(q_t)
193
+ q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
194
+
195
+ k_t = k[:, :, 1:, :]
196
+ ro_k_t = self.rope(k_t)
197
+ k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
198
+
199
+ if self.xattn:
200
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
201
+ k = k.permute(0, 2, 1, 3)
202
+ v = v.permute(0, 2, 1, 3)
203
+
204
+ x = xops.memory_efficient_attention(
205
+ q,
206
+ k,
207
+ v,
208
+ p=self.xattn_drop,
209
+ scale=self.scale,
210
+ )
211
+ x = x.reshape(B, N, -1)
212
+ x = self.inner_attn_ln(x)
213
+ x = self.proj(x)
214
+ x = self.proj_drop(x)
215
+ else:
216
+ q = q * self.scale
217
+ attn = q @ k.transpose(-2, -1)
218
+
219
+ if self.relative_position_bias_table is not None:
220
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
221
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
222
+ attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
223
+
224
+ if rel_pos_bias is not None:
225
+ attn = attn + rel_pos_bias.type_as(attn)
226
+
227
+ if attn_mask is not None:
228
+ attn_mask = attn_mask.bool()
229
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
230
+
231
+ attn = attn.softmax(dim=-1)
232
+ attn = self.attn_drop(attn)
233
+
234
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
235
+ x = self.inner_attn_ln(x)
236
+ x = self.proj(x)
237
+ x = self.proj_drop(x)
238
+ return x
239
+
240
+
241
+ class Block(nn.Module):
242
+
243
+ def __init__(
244
+ self,
245
+ dim,
246
+ num_heads,
247
+ mlp_ratio=4.0,
248
+ qkv_bias=False,
249
+ qk_scale=None,
250
+ drop=0.0,
251
+ attn_drop=0.0,
252
+ drop_path=0.0,
253
+ init_values=None,
254
+ act_layer=nn.GELU,
255
+ norm_layer=nn.LayerNorm,
256
+ window_size=None,
257
+ attn_head_dim=None,
258
+ xattn=False,
259
+ rope=None,
260
+ postnorm=False,
261
+ subln=False,
262
+ naiveswiglu=False,
263
+ ):
264
+ super().__init__()
265
+ self.norm1 = norm_layer(dim)
266
+ self.attn = Attention(
267
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim, xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer
268
+ )
269
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
270
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
271
+ self.norm2 = norm_layer(dim)
272
+ mlp_hidden_dim = int(dim * mlp_ratio)
273
+
274
+ if naiveswiglu:
275
+ self.mlp = SwiGLU(
276
+ in_features=dim,
277
+ hidden_features=mlp_hidden_dim,
278
+ subln=subln,
279
+ norm_layer=norm_layer,
280
+ )
281
+ else:
282
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, subln=subln, drop=drop)
283
+
284
+ if init_values is not None and init_values > 0:
285
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
286
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
287
+ else:
288
+ self.gamma_1, self.gamma_2 = None, None
289
+
290
+ self.postnorm = postnorm
291
+
292
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
293
+ if self.gamma_1 is None:
294
+ if self.postnorm:
295
+ x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
296
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
297
+ else:
298
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
299
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
300
+ else:
301
+ if self.postnorm:
302
+ x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
303
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
304
+ else:
305
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
306
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
307
+ return x
308
+
309
+
310
+ class PatchEmbed(nn.Module):
311
+ """Image to Patch Embedding"""
312
+
313
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
314
+ super().__init__()
315
+ img_size = to_2tuple(img_size)
316
+ patch_size = to_2tuple(patch_size)
317
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
318
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
319
+ self.img_size = img_size
320
+ self.patch_size = patch_size
321
+ self.num_patches = num_patches
322
+
323
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
324
+
325
+ def forward(self, x, **kwargs):
326
+ B, C, H, W = x.shape
327
+ # FIXME look at relaxing size constraints
328
+ assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
329
+ x = self.proj(x).flatten(2).transpose(1, 2)
330
+ return x
331
+
332
+
333
+ class RelativePositionBias(nn.Module):
334
+
335
+ def __init__(self, window_size, num_heads):
336
+ super().__init__()
337
+ self.window_size = window_size
338
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
339
+ self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
340
+ # cls to token & token 2 cls & cls to cls
341
+
342
+ # get pair-wise relative position index for each token inside the window
343
+ coords_h = torch.arange(window_size[0])
344
+ coords_w = torch.arange(window_size[1])
345
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
346
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
347
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
348
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
349
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
350
+ relative_coords[:, :, 1] += window_size[1] - 1
351
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
352
+ relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
353
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
354
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
355
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
356
+ relative_position_index[0, 0] = self.num_relative_distance - 1
357
+
358
+ self.register_buffer("relative_position_index", relative_position_index)
359
+
360
+ def forward(self):
361
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
362
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
363
+
364
+
365
+ class EVAVisionTransformer(nn.Module):
366
+ """Vision Transformer with support for patch or hybrid CNN input stage"""
367
+
368
+ def __init__(
369
+ self,
370
+ img_size=224,
371
+ patch_size=16,
372
+ in_chans=3,
373
+ num_classes=1000,
374
+ embed_dim=768,
375
+ depth=12,
376
+ num_heads=12,
377
+ mlp_ratio=4.0,
378
+ qkv_bias=False,
379
+ qk_scale=None,
380
+ drop_rate=0.0,
381
+ attn_drop_rate=0.0,
382
+ drop_path_rate=0.0,
383
+ norm_layer=nn.LayerNorm,
384
+ init_values=None,
385
+ patch_dropout=0.0,
386
+ use_abs_pos_emb=True,
387
+ use_rel_pos_bias=False,
388
+ use_shared_rel_pos_bias=False,
389
+ rope=False,
390
+ use_mean_pooling=True,
391
+ init_scale=0.001,
392
+ grad_checkpointing=False,
393
+ xattn=False,
394
+ postnorm=False,
395
+ pt_hw_seq_len=16,
396
+ intp_freq=False,
397
+ naiveswiglu=False,
398
+ subln=False,
399
+ ):
400
+ super().__init__()
401
+ self.image_size = img_size
402
+ self.num_classes = num_classes
403
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
404
+
405
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
406
+ num_patches = self.patch_embed.num_patches
407
+
408
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
409
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
410
+ if use_abs_pos_emb:
411
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
412
+ else:
413
+ self.pos_embed = None
414
+ self.pos_drop = nn.Dropout(p=drop_rate)
415
+
416
+ if use_shared_rel_pos_bias:
417
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
418
+ else:
419
+ self.rel_pos_bias = None
420
+
421
+ if rope:
422
+ half_head_dim = embed_dim // num_heads // 2
423
+ hw_seq_len = img_size // patch_size
424
+ self.rope = VisionRotaryEmbeddingFast(
425
+ dim=half_head_dim,
426
+ pt_seq_len=pt_hw_seq_len,
427
+ ft_seq_len=hw_seq_len if intp_freq else None,
428
+ # patch_dropout=patch_dropout
429
+ )
430
+ else:
431
+ self.rope = None
432
+
433
+ self.naiveswiglu = naiveswiglu
434
+
435
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
436
+ self.use_rel_pos_bias = use_rel_pos_bias
437
+ self.blocks = nn.ModuleList(
438
+ [
439
+ Block(
440
+ dim=embed_dim,
441
+ num_heads=num_heads,
442
+ mlp_ratio=mlp_ratio,
443
+ qkv_bias=qkv_bias,
444
+ qk_scale=qk_scale,
445
+ drop=drop_rate,
446
+ attn_drop=attn_drop_rate,
447
+ drop_path=dpr[i],
448
+ norm_layer=norm_layer,
449
+ init_values=init_values,
450
+ window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
451
+ xattn=xattn,
452
+ rope=self.rope,
453
+ postnorm=postnorm,
454
+ subln=subln,
455
+ naiveswiglu=naiveswiglu,
456
+ )
457
+ for i in range(depth)
458
+ ]
459
+ )
460
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
461
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
462
+ self.head = nn.Linear(embed_dim, num_classes, bias=qkv_bias) if num_classes > 0 else nn.Identity()
463
+
464
+ if self.pos_embed is not None:
465
+ trunc_normal_(self.pos_embed, std=0.02)
466
+
467
+ trunc_normal_(self.cls_token, std=0.02)
468
+
469
+ self.apply(self._init_weights)
470
+ self.fix_init_weight()
471
+
472
+ if isinstance(self.head, nn.Linear):
473
+ trunc_normal_(self.head.weight, std=0.02)
474
+ self.head.weight.data.mul_(init_scale)
475
+ if self.head.bias is not None:
476
+ self.head.bias.data.mul_(init_scale)
477
+
478
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
479
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
480
+
481
+ self.grad_checkpointing = grad_checkpointing
482
+
483
+ def fix_init_weight(self):
484
+ def rescale(param, layer_id):
485
+ param.div_(math.sqrt(2.0 * layer_id))
486
+
487
+ for layer_id, layer in enumerate(self.blocks):
488
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
489
+ if self.naiveswiglu:
490
+ rescale(layer.mlp.w3.weight.data, layer_id + 1)
491
+ else:
492
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
493
+
494
+ def get_cast_dtype(self) -> torch.dtype:
495
+ return self.blocks[0].mlp.fc2.weight.dtype
496
+
497
+ def _init_weights(self, m):
498
+ if isinstance(m, nn.Linear):
499
+ trunc_normal_(m.weight, std=0.02)
500
+ if m.bias is not None:
501
+ nn.init.constant_(m.bias, 0)
502
+ elif isinstance(m, nn.LayerNorm):
503
+ nn.init.constant_(m.bias, 0)
504
+ nn.init.constant_(m.weight, 1.0)
505
+
506
+ def get_num_layers(self):
507
+ return len(self.blocks)
508
+
509
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
510
+ assert unlocked_groups == 0, "partial locking not currently supported for this model"
511
+ for param in self.parameters():
512
+ param.requires_grad = False
513
+
514
+ @torch.jit.ignore
515
+ def set_grad_checkpointing(self, enable=True):
516
+ self.grad_checkpointing = enable
517
+
518
+ @torch.jit.ignore
519
+ def no_weight_decay(self):
520
+ return {"pos_embed", "cls_token"}
521
+
522
+ def get_classifier(self):
523
+ return self.head
524
+
525
+ def reset_classifier(self, num_classes, global_pool=""):
526
+ self.num_classes = num_classes
527
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
528
+
529
+ def forward_features(self, x, return_all_features=False):
530
+
531
+ x = self.patch_embed(x)
532
+ batch_size, seq_len, _ = x.size()
533
+
534
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
535
+ x = torch.cat((cls_tokens, x), dim=1)
536
+ if self.pos_embed is not None:
537
+ x = x + self.pos_embed
538
+ x = self.pos_drop(x)
539
+
540
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
541
+ # if os.getenv("RoPE") == "1":
542
+ # if self.training and not isinstance(self.patch_dropout, nn.Identity):
543
+ # x, patch_indices_keep = self.patch_dropout(x)
544
+ # self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
545
+ # else:
546
+ # self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
547
+ # x = self.patch_dropout(x)
548
+ # else:
549
+ x = self.patch_dropout(x)
550
+
551
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
552
+ for blk in self.blocks:
553
+ if self.grad_checkpointing:
554
+ x = checkpoint(blk, x, rel_pos_bias)
555
+ else:
556
+ x = blk(x, rel_pos_bias=rel_pos_bias)
557
+
558
+ if not return_all_features:
559
+ x = self.norm(x)
560
+ if self.fc_norm is not None:
561
+ return self.fc_norm(x.mean(1))
562
+ else:
563
+ return x[:, 0]
564
+ return x
565
+
566
+ def forward(self, x, return_all_features=False):
567
+ if return_all_features:
568
+ return self.forward_features(x, return_all_features)
569
+ x = self.forward_features(x)
570
+ x = self.head(x)
571
+ return x
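Smoke-test sketch for the transformer above; the hyperparameters are deliberately tiny and illustrative, not the EVA-CLIP production configuration:

from functools import partial
import torch
import torch.nn as nn

tiny = EVAVisionTransformer(
    img_size=224, patch_size=16, embed_dim=192, depth=2, num_heads=3,
    num_classes=0,                        # head becomes nn.Identity()
    qkv_bias=True, mlp_ratio=4.0,
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
    use_mean_pooling=False,               # return the CLS token instead of mean-pooled features
)
x = torch.randn(2, 3, 224, 224)
cls_embedding = tiny(x)                         # (2, 192)
all_tokens = tiny(x, return_all_features=True)  # (2, 197, 192): CLS + 14*14 patch tokens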
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/factory.py ADDED
@@ -0,0 +1,528 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+ from typing import Optional, Tuple, Union, Dict, Any
9
+ import torch
10
+
11
+ try:
12
+ import deepspeed
13
+ except ImportError:
14
+ deepspeed = None
15
+
16
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
17
+ from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict, get_cast_dtype
18
+ from .openai import load_openai_model
19
+ from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
20
+ from .transform import image_transform
21
+ from .tokenizer import HFTokenizer, tokenize
22
+ from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed
23
+
24
+
25
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
26
+ _MODEL_CONFIGS = {} # dictionary (model_name: config) of model architecture configs
27
+
28
+
29
+ def _natural_key(string_):
30
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
31
+
32
+
33
+ def _rescan_model_configs():
34
+ global _MODEL_CONFIGS
35
+
36
+ config_ext = (".json",)
37
+ config_files = []
38
+ for config_path in _MODEL_CONFIG_PATHS:
39
+ if config_path.is_file() and config_path.suffix in config_ext:
40
+ config_files.append(config_path)
41
+ elif config_path.is_dir():
42
+ for ext in config_ext:
43
+ config_files.extend(config_path.glob(f"*{ext}"))
44
+
45
+ for cf in config_files:
46
+ with open(cf, "r", encoding="utf8") as f:
47
+ model_cfg = json.load(f)
48
+ if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")):
49
+ _MODEL_CONFIGS[cf.stem] = model_cfg
50
+
51
+ _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
52
+
53
+
54
+ _rescan_model_configs() # initial populate of model config registry
55
+
56
+
57
+ def list_models():
58
+ """enumerate available model architectures based on config files"""
59
+ return list(_MODEL_CONFIGS.keys())
60
+
61
+
62
+ def add_model_config(path):
63
+ """add model config path or file and update registry"""
64
+ if not isinstance(path, Path):
65
+ path = Path(path)
66
+ _MODEL_CONFIG_PATHS.append(path)
67
+ _rescan_model_configs()
68
+
69
+
70
+ def get_model_config(model_name):
71
+ if model_name in _MODEL_CONFIGS:
72
+ return deepcopy(_MODEL_CONFIGS[model_name])
73
+ else:
74
+ return None
75
+
76
+
77
+ def get_tokenizer(model_name):
78
+ config = get_model_config(model_name)
79
+ tokenizer = HFTokenizer(config["text_cfg"]["hf_tokenizer_name"]) if "hf_tokenizer_name" in config["text_cfg"] else tokenize
80
+ return tokenizer
81
+
82
+
83
+ # loading openai CLIP weights when is_openai=True for training
84
+ def load_state_dict(checkpoint_path: str, map_location: str = "cpu", model_key: str = "model|module|state_dict", is_openai: bool = False, skip_list: list = []):
85
+ if is_openai:
86
+ model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
87
+ state_dict = model.state_dict()
88
+ for key in ["input_resolution", "context_length", "vocab_size"]:
89
+ state_dict.pop(key, None)
90
+ else:
91
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
92
+ for mk in model_key.split("|"):
93
+ if isinstance(checkpoint, dict) and mk in checkpoint:
94
+ state_dict = checkpoint[mk]
95
+ break
96
+ else:
97
+ state_dict = checkpoint
98
+ if next(iter(state_dict.items()))[0].startswith("module"):
99
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
100
+
101
+ for k in skip_list:
102
+ if k in list(state_dict.keys()):
103
+ logging.info(f"Removing key {k} from pretrained checkpoint")
104
+ del state_dict[k]
105
+
106
+ if os.getenv("RoPE") == "1":
107
+ for k in list(state_dict.keys()):
108
+ if "freqs_cos" in k or "freqs_sin" in k:
109
+ del state_dict[k]
110
+ return state_dict
111
+
112
+
113
+ def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True):
114
+ state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)
115
+ # detect old format and make compatible with new format
116
+ if "positional_embedding" in state_dict and not hasattr(model, "positional_embedding"):
117
+ state_dict = convert_to_custom_text_state_dict(state_dict)
118
+ if "text.logit_scale" in state_dict and hasattr(model, "logit_scale"):
119
+ state_dict["logit_scale"] = state_dict["text.logit_scale"]
120
+ del state_dict["text.logit_scale"]
121
+
122
+ # resize_clip_pos_embed for CLIP and open CLIP
123
+ if "visual.positional_embedding" in state_dict:
124
+ resize_clip_pos_embed(state_dict, model)
125
+ # specified to eva_vit_model
126
+ elif "visual.pos_embed" in state_dict:
127
+ resize_evaclip_pos_embed(state_dict, model)
128
+
129
+ # resize_clip_pos_embed(state_dict, model)
130
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
131
+ logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
132
+ return incompatible_keys
133
+
134
+
135
+ def load_clip_visual_state_dict(checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []):
136
+ state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
137
+
138
+ for k in list(state_dict.keys()):
139
+ if not k.startswith("visual."):
140
+ del state_dict[k]
141
+ for k in list(state_dict.keys()):
142
+ if k.startswith("visual."):
143
+ new_k = k[7:]
144
+ state_dict[new_k] = state_dict[k]
145
+ del state_dict[k]
146
+ return state_dict
147
+
148
+
149
+ def load_clip_text_state_dict(checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []):
150
+ state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
151
+
152
+ for k in list(state_dict.keys()):
153
+ if k.startswith("visual."):
154
+ del state_dict[k]
155
+ return state_dict
156
+
157
+
158
+ def get_pretrained_tag(pretrained_model):
159
+ pretrained_model = pretrained_model.lower()
160
+ if "laion" in pretrained_model or "open_clip" in pretrained_model:
161
+ return "open_clip"
162
+ elif "openai" in pretrained_model:
163
+ return "clip"
164
+ elif "eva" in pretrained_model and "clip" in pretrained_model:
165
+ return "eva_clip"
166
+ else:
167
+ return "other"
168
+
169
+
170
+ def load_zero_partitions(model, state_dict, is_deepspeed_zero3_enabled, pretrained_model_path, ignore_mismatched_sizes=False):
171
+ """
172
+ adapted from PyTorch Lightning and HuggingFace Transformers
173
+ with deepspeed.zero.Init():
174
+ model = MyModel()
175
+ state_dict = torch.load(model_path, map_location="cpu")
176
+ load_zero_partitions(model, state_dict, is_deepspeed_zero3_enabled=True, pretrained_model_path=model_path)
177
+ """
178
+
179
+ # because zero3 puts placeholders in model params, this context
180
+ # manager gathers (unpartitions) the params of the current layer, then loads from
181
+ # the state dict and then re-partitions them again
182
+ model_state_dict = model.state_dict()
183
+ expected_keys = list(model_state_dict.keys())
184
+ loaded_keys = list(state_dict.keys())
185
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
186
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
187
+
188
+ # Mismatched keys contain tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
189
+ # matching the weights in the model.
190
+ mismatched_keys = []
191
+ if ignore_mismatched_sizes:
192
+ for checkpoint_key in loaded_keys:
193
+ model_key = checkpoint_key
194
+
195
+ if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
196
+ mismatched_keys.append((checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape))
197
+ del state_dict[checkpoint_key]
198
+ # copy state_dict so _load_from_state_dict can modify it
199
+ metadata = getattr(state_dict, "_metadata", None)
200
+ state_dict = state_dict.copy()
201
+ if metadata is not None:
202
+ state_dict._metadata = metadata
203
+
204
+ error_msgs = []
205
+
206
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
207
+ # so we need to apply the function recursively.
208
+ def load(module, prefix=""):
209
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
210
+ args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
211
+ if is_deepspeed_zero3_enabled:
212
+ # because zero3 puts placeholders in model params, this context
213
+ # manager gathers (unpartitions) the params of the current layer, then loads from
214
+ # the state dict and then re-partitions them again
215
+ with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
216
+ if torch.distributed.get_rank() == 0:
217
+ module._load_from_state_dict(*args)
218
+ else:
219
+ module._load_from_state_dict(*args)
220
+
221
+ for name, child in module._modules.items():
222
+ if child is not None:
223
+ load(child, prefix + name + ".")
224
+
225
+ # Make sure we are able to load base models as well as derived models (with heads)
226
+ start_prefix = ""
227
+ model_to_load = model
228
+ load(model_to_load, prefix=start_prefix)
229
+ del state_dict
230
+ if len(error_msgs) > 0:
231
+ error_msg = "\n\t".join(error_msgs)
232
+ if "size mismatch" in error_msg:
233
+ error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
234
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
235
+ if len(unexpected_keys) > 0:
236
+ logging.warning(
237
+ f"Some weights of the model checkpoint at {pretrained_model_path} were not used when"
238
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
239
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
240
+ " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
241
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
242
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
243
+ " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
244
+ )
245
+ else:
246
+ logging.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
247
+ if len(missing_keys) > 0:
248
+ logging.warning(
249
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
250
+ f" {pretrained_model_path} and are newly initialized: {missing_keys}\nYou should probably"
251
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
252
+ )
253
+ elif len(mismatched_keys) == 0:
254
+ logging.info(
255
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
256
+ f" {pretrained_model_path}.\nIf your task is similar to the task the model of the checkpoint"
257
+ f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
258
+ " training."
259
+ )
260
+ if len(mismatched_keys) > 0:
261
+ mismatched_warning = "\n".join([f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" for key, shape1, shape2 in mismatched_keys])
262
+ logging.warning(
263
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
264
+ f" {pretrained_model_path} and are newly initialized because the shapes did not"
265
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
266
+ " to use it for predictions and inference."
267
+ )
268
+
269
+
270
+ def load_pretrained_checkpoint(model, visual_checkpoint_path, text_checkpoint_path, strict=True, visual_model=None, text_model=None, model_key="model|module|state_dict", skip_list=[]):
271
+ visual_tag = get_pretrained_tag(visual_model)
272
+ text_tag = get_pretrained_tag(text_model)
273
+
274
+ logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}")
275
+ visual_incompatible_keys, text_incompatible_keys = None, None
276
+ if visual_checkpoint_path:
277
+ if visual_tag == "eva_clip" or visual_tag == "open_clip":
278
+ visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list)
279
+ elif visual_tag == "clip":
280
+ visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
281
+ else:
282
+ visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
283
+
284
+ # resize_clip_pos_embed for CLIP and open CLIP
285
+ if "positional_embedding" in visual_state_dict:
286
+ resize_visual_pos_embed(visual_state_dict, model)
287
+ # specified to EVA model
288
+ elif "pos_embed" in visual_state_dict:
289
+ resize_eva_pos_embed(visual_state_dict, model)
290
+
291
+ visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)
292
+ logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}")
293
+ logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}")
294
+
295
+ if text_checkpoint_path:
296
+ if text_tag == "eva_clip" or text_tag == "open_clip":
297
+ text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)
298
+ elif text_tag == "clip":
299
+ text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)
300
+ else:
301
+ text_state_dict = load_state_dict(text_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
302
+
303
+ text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)
304
+
305
+ logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
306
+ logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")
307
+
308
+ return visual_incompatible_keys, text_incompatible_keys
309
+
310
+
311
+ def create_model(
312
+ model_name: str,
313
+ pretrained: Optional[str] = None,
314
+ precision: str = "fp32",
315
+ device: Union[str, torch.device] = "cpu",
316
+ jit: bool = False,
317
+ force_quick_gelu: bool = False,
318
+ force_custom_clip: bool = False,
319
+ force_patch_dropout: Optional[float] = None,
320
+ pretrained_image: str = "",
321
+ pretrained_text: str = "",
322
+ pretrained_hf: bool = True,
323
+ pretrained_visual_model: str = None,
324
+ pretrained_text_model: str = None,
325
+ cache_dir: Optional[str] = None,
326
+ skip_list: list = [],
327
+ ):
328
+ model_name = model_name.replace("/", "-") # for callers using old naming with / in ViT names
329
+ if isinstance(device, str):
330
+ device = torch.device(device)
331
+
332
+ if pretrained and pretrained.lower() == "openai":
333
+ logging.info(f"Loading pretrained {model_name} from OpenAI.")
334
+ model = load_openai_model(
335
+ model_name,
336
+ precision=precision,
337
+ device=device,
338
+ jit=jit,
339
+ cache_dir=cache_dir,
340
+ )
341
+ else:
342
+ model_cfg = get_model_config(model_name)
343
+ if model_cfg is not None:
344
+ logging.info(f"Loaded {model_name} model config.")
345
+ else:
346
+ logging.error(f"Model config for {model_name} not found; available models {list_models()}.")
347
+ raise RuntimeError(f"Model config for {model_name} not found.")
348
+
349
+ if "rope" in model_cfg.get("vision_cfg", {}):
350
+ if model_cfg["vision_cfg"]["rope"]:
351
+ os.environ["RoPE"] = "1"
352
+ else:
353
+ os.environ["RoPE"] = "0"
354
+
355
+ if force_quick_gelu:
356
+ # override for use of QuickGELU on non-OpenAI transformer models
357
+ model_cfg["quick_gelu"] = True
358
+
359
+ if force_patch_dropout is not None:
360
+ # override the default patch dropout value
361
+ model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
362
+
363
+ cast_dtype = get_cast_dtype(precision)
364
+ custom_clip = model_cfg.pop("custom_text", False) or force_custom_clip or ("hf_model_name" in model_cfg["text_cfg"])
365
+
366
+ if custom_clip:
367
+ if "hf_model_name" in model_cfg.get("text_cfg", {}):
368
+ model_cfg["text_cfg"]["hf_model_pretrained"] = pretrained_hf
369
+ model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype)
370
+ else:
371
+ model = CLIP(**model_cfg, cast_dtype=cast_dtype)
372
+
373
+ pretrained_cfg = {}
374
+ if pretrained:
375
+ checkpoint_path = ""
376
+ pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
377
+ if pretrained_cfg:
378
+ checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
379
+ elif os.path.exists(pretrained):
380
+ checkpoint_path = pretrained
381
+
382
+ if checkpoint_path:
383
+ logging.info(f"Loading pretrained {model_name} weights ({pretrained}).")
384
+ load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=False)
385
+ else:
386
+ error_str = f"Pretrained weights ({pretrained}) not found for model {model_name}. " f"Available pretrained tags: {list_pretrained_tags_by_model(model_name)}."
387
+ logging.warning(error_str)
388
+ raise RuntimeError(error_str)
389
+ else:
390
+ visual_checkpoint_path = ""
391
+ text_checkpoint_path = ""
392
+
393
+ if pretrained_image:
394
+ pretrained_visual_model = pretrained_visual_model.replace("/", "-") # for callers using old naming with / in ViT names
395
+ pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)
396
+ if "timm_model_name" in model_cfg.get("vision_cfg", {}):
397
+ # pretrained weight loading for timm models set via vision_cfg
398
+ model_cfg["vision_cfg"]["timm_model_pretrained"] = True
399
+ elif pretrained_image_cfg:
400
+ visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir)
401
+ elif os.path.exists(pretrained_image):
402
+ visual_checkpoint_path = pretrained_image
403
+ else:
404
+ logging.warning(f"Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.")
405
+ raise RuntimeError(f"Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.")
406
+
407
+ if pretrained_text:
408
+ pretrained_text_model = pretrained_text_model.replace("/", "-") # for callers using old naming with / in ViT names
409
+ pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text)
410
+ if pretrained_text_cfg:
411
+ text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir)
412
+ elif os.path.exists(pretrained_text):
413
+ text_checkpoint_path = pretrained_text
414
+ else:
415
+ logging.warning(f"Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.")
416
+ raise RuntimeError(f"Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.")
417
+
418
+ if visual_checkpoint_path:
419
+ logging.info(f"Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).")
420
+ if text_checkpoint_path:
421
+ logging.info(f"Loading pretrained {model_name}.text weights ({text_checkpoint_path}).")
422
+
423
+ if visual_checkpoint_path or text_checkpoint_path:
424
+ load_pretrained_checkpoint(model, visual_checkpoint_path, text_checkpoint_path, strict=False, visual_model=pretrained_visual_model, text_model=pretrained_text_model, model_key="model|module|state_dict", skip_list=skip_list)
425
+
426
+ if "fp16" in precision or "bf16" in precision:
427
+ logging.info(f"convert precision to {precision}")
428
+ model = model.to(torch.bfloat16) if "bf16" in precision else model.to(torch.float16)
429
+
430
+ # model.to(device=device)
431
+
432
+ # set image / mean metadata from pretrained_cfg if available, or use default
433
+ model.visual.image_mean = pretrained_cfg.get("mean", None) or OPENAI_DATASET_MEAN
434
+ model.visual.image_std = pretrained_cfg.get("std", None) or OPENAI_DATASET_STD
435
+
436
+ if jit:
437
+ model = torch.jit.script(model)
438
+
439
+ return model
440
+
441
+
442
+ def create_model_and_transforms(
443
+ model_name: str,
444
+ pretrained: Optional[str] = None,
445
+ precision: str = "fp32",
446
+ device: Union[str, torch.device] = "cpu",
447
+ jit: bool = False,
448
+ force_quick_gelu: bool = False,
449
+ force_custom_clip: bool = False,
450
+ force_patch_dropout: Optional[float] = None,
451
+ pretrained_image: str = "",
452
+ pretrained_text: str = "",
453
+ pretrained_hf: bool = True,
454
+ pretrained_visual_model: str = None,
455
+ pretrained_text_model: str = None,
456
+ image_mean: Optional[Tuple[float, ...]] = None,
457
+ image_std: Optional[Tuple[float, ...]] = None,
458
+ cache_dir: Optional[str] = None,
459
+ skip_list: list = [],
460
+ ):
461
+ model = create_model(
462
+ model_name,
463
+ pretrained,
464
+ precision=precision,
465
+ device=device,
466
+ jit=jit,
467
+ force_quick_gelu=force_quick_gelu,
468
+ force_custom_clip=force_custom_clip,
469
+ force_patch_dropout=force_patch_dropout,
470
+ pretrained_image=pretrained_image,
471
+ pretrained_text=pretrained_text,
472
+ pretrained_hf=pretrained_hf,
473
+ pretrained_visual_model=pretrained_visual_model,
474
+ pretrained_text_model=pretrained_text_model,
475
+ cache_dir=cache_dir,
476
+ skip_list=skip_list,
477
+ )
478
+
479
+ image_mean = image_mean or getattr(model.visual, "image_mean", None)
480
+ image_std = image_std or getattr(model.visual, "image_std", None)
481
+ preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=image_mean, std=image_std)
482
+ preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=image_mean, std=image_std)
483
+
484
+ return model, preprocess_train, preprocess_val
485
+
486
+
487
+ def create_model_from_pretrained(
488
+ model_name: str,
489
+ pretrained: str,
490
+ precision: str = "fp32",
491
+ device: Union[str, torch.device] = "cpu",
492
+ jit: bool = False,
493
+ force_quick_gelu: bool = False,
494
+ force_custom_clip: bool = False,
495
+ force_patch_dropout: Optional[float] = None,
496
+ return_transform: bool = True,
497
+ image_mean: Optional[Tuple[float, ...]] = None,
498
+ image_std: Optional[Tuple[float, ...]] = None,
499
+ cache_dir: Optional[str] = None,
500
+ is_frozen: bool = False,
501
+ ):
502
+ if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
503
+ raise RuntimeError(f"{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}." f" Use open_clip.list_pretrained() to find one.")
504
+
505
+ model = create_model(
506
+ model_name,
507
+ pretrained,
508
+ precision=precision,
509
+ device=device,
510
+ jit=jit,
511
+ force_quick_gelu=force_quick_gelu,
512
+ force_custom_clip=force_custom_clip,
513
+ force_patch_dropout=force_patch_dropout,
514
+ cache_dir=cache_dir,
515
+ )
516
+
517
+ if is_frozen:
518
+ for param in model.parameters():
519
+ param.requires_grad = False
520
+
521
+ if not return_transform:
522
+ return model
523
+
524
+ image_mean = image_mean or getattr(model.visual, "image_mean", None)
525
+ image_std = image_std or getattr(model.visual, "image_std", None)
526
+ preprocess = image_transform(model.visual.image_size, is_train=False, mean=image_mean, std=image_std)
527
+
528
+ return model, preprocess
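Hedged example of the loader entry points above; the model tag, the checkpoint path, and the encode_image call (the usual open_clip-style interface on the CLIP/CustomCLIP classes in model.py) are placeholders and assumptions, not values fixed by this file:

import torch
from PIL import Image

model, preprocess = create_model_from_pretrained(
    "EVA02-CLIP-L-14-336",                                 # placeholder model tag
    pretrained="/path/to/EVA02_CLIP_L_336_psz14_s6B.pt",   # placeholder local checkpoint
    force_custom_clip=True,
    is_frozen=True,   # freezes every parameter, convenient for feature extraction
)
pixel = preprocess(Image.new("RGB", (336, 336))).unsqueeze(0)
with torch.no_grad():
    image_features = model.encode_image(pixel)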
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py ADDED
@@ -0,0 +1,57 @@
1
+ # HF architecture dict:
2
+ arch_dict = {
3
+ # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4
+ "roberta": {
5
+ "config_names": {
6
+ "context_length": "max_position_embeddings",
7
+ "vocab_size": "vocab_size",
8
+ "width": "hidden_size",
9
+ "heads": "num_attention_heads",
10
+ "layers": "num_hidden_layers",
11
+ "layer_attr": "layer",
12
+ "token_embeddings_attr": "embeddings",
13
+ },
14
+ "pooler": "mean_pooler",
15
+ },
16
+ # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17
+ "xlm-roberta": {
18
+ "config_names": {
19
+ "context_length": "max_position_embeddings",
20
+ "vocab_size": "vocab_size",
21
+ "width": "hidden_size",
22
+ "heads": "num_attention_heads",
23
+ "layers": "num_hidden_layers",
24
+ "layer_attr": "layer",
25
+ "token_embeddings_attr": "embeddings",
26
+ },
27
+ "pooler": "mean_pooler",
28
+ },
29
+ # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30
+ "mt5": {
31
+ "config_names": {
32
+ # unlimited seqlen
33
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35
+ "context_length": "",
36
+ "vocab_size": "vocab_size",
37
+ "width": "d_model",
38
+ "heads": "num_heads",
39
+ "layers": "num_layers",
40
+ "layer_attr": "block",
41
+ "token_embeddings_attr": "embed_tokens",
42
+ },
43
+ "pooler": "mean_pooler",
44
+ },
45
+ "bert": {
46
+ "config_names": {
47
+ "context_length": "max_position_embeddings",
48
+ "vocab_size": "vocab_size",
49
+ "width": "hidden_size",
50
+ "heads": "num_attention_heads",
51
+ "layers": "num_hidden_layers",
52
+ "layer_attr": "layer",
53
+ "token_embeddings_attr": "embeddings",
54
+ },
55
+ "pooler": "mean_pooler",
56
+ },
57
+ }
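The table above lets the HF adapter in the next file read architecture attributes without model-specific code. A short illustration, using "roberta-base" as an arbitrary example Hub id:

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("roberta-base")
names = arch_dict[cfg.model_type]["config_names"]
width = getattr(cfg, names["width"])       # hidden_size -> 768
layers = getattr(cfg, names["layers"])     # num_hidden_layers -> 12
default_pooler = arch_dict[cfg.model_type]["pooler"]  # "mean_pooler"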
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_model.py ADDED
@@ -0,0 +1,240 @@
1
+ """ huggingface model adapter
2
+
3
+ Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4
+ """
5
+
6
+ import re
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+ from torch import TensorType
12
+
13
+ try:
14
+ import transformers
15
+ from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
16
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions
17
+ except ImportError as e:
18
+ transformers = None
19
+
20
+ class BaseModelOutput:
21
+ pass
22
+
23
+ class PretrainedConfig:
24
+ pass
25
+
26
+
27
+ from .hf_configs import arch_dict
28
+
29
+
30
+ # utils
31
+ def _camel2snake(s):
32
+ return re.sub(r"(?<!^)(?=[A-Z])", "_", s).lower()
33
+
34
+
35
+ # TODO: ?last - for gpt-like models
36
+ _POOLERS = {}
37
+
38
+
39
+ def register_pooler(cls):
40
+ """Decorator registering pooler class"""
41
+ _POOLERS[_camel2snake(cls.__name__)] = cls
42
+ return cls
43
+
44
+
45
+ @register_pooler
46
+ class MeanPooler(nn.Module):
47
+ """Mean pooling"""
48
+
49
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
50
+ masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
51
+ return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
52
+
53
+
54
+ @register_pooler
55
+ class MaxPooler(nn.Module):
56
+ """Max pooling"""
57
+
58
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
59
+ masked_output = x.last_hidden_state.masked_fill((attention_mask == 0).unsqueeze(-1), -torch.inf)  # mask out padded positions before taking the max
60
+ return masked_output.max(1).values
61
+
62
+
63
+ @register_pooler
64
+ class ClsPooler(nn.Module):
65
+ """CLS token pooling"""
66
+
67
+ def __init__(self, use_pooler_output=True):
68
+ super().__init__()
69
+ self.cls_token_position = 0
70
+ self.use_pooler_output = use_pooler_output
71
+
72
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
73
+
74
+ if self.use_pooler_output and isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and (x.pooler_output is not None):
75
+ return x.pooler_output
76
+
77
+ return x.last_hidden_state[:, self.cls_token_position, :]
78
+
79
+
80
+ class HFTextEncoder(nn.Module):
81
+ """HuggingFace model adapter"""
82
+
83
+ def __init__(self, model_name_or_path: str, output_dim: int, tokenizer_name: str = None, config: PretrainedConfig = None, pooler_type: str = None, proj: str = None, pretrained: bool = True, masked_language_modeling: bool = False):
84
+ super().__init__()
85
+
86
+ self.output_dim = output_dim
87
+
88
+ # TODO: find better way to get this information
89
+ uses_transformer_pooler = pooler_type == "cls_pooler"
90
+
91
+ if transformers is None:
92
+ raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
93
+ if config is None:
94
+ self.config = AutoConfig.from_pretrained(model_name_or_path)
95
+ if masked_language_modeling:
96
+ create_func, model_args = (AutoModelForMaskedLM.from_pretrained, model_name_or_path) if pretrained else (AutoModelForMaskedLM.from_config, self.config)
97
+ else:
98
+ create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (AutoModel.from_config, self.config)
99
+ # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
100
+ if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
101
+ self.transformer = create_func(model_args)
102
+ self.transformer = self.transformer.encoder
103
+ else:
104
+ self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
105
+ else:
106
+ self.config = config
107
+ if masked_language_modeling:
108
+ self.transformer = AutoModelForMaskedLM.from_config(config)
109
+ else:
110
+ self.transformer = AutoModel.from_config(config)
111
+
112
+ if pooler_type is None: # get default arch pooler
113
+ self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
114
+ else:
115
+ self.pooler = _POOLERS[pooler_type]()
116
+
117
+ d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
118
+ if (d_model == output_dim) and (proj is None): # do we always need a proj?
119
+ self.proj = nn.Identity()
120
+ elif proj == "linear":
121
+ self.proj = nn.Linear(d_model, output_dim, bias=False)
122
+ elif proj == "mlp":
123
+ hidden_size = (d_model + output_dim) // 2
124
+ self.proj = nn.Sequential(
125
+ nn.Linear(d_model, hidden_size, bias=False),
126
+ nn.GELU(),
127
+ nn.Linear(hidden_size, output_dim, bias=False),
128
+ )
129
+
130
+ # self.itm_proj = nn.Linear(d_model, 2, bias=False)
131
+ # self.mlm_proj = nn.Linear(d_model, self.config.vocab_size), bias=False)
132
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
133
+
134
+ # def forward_itm(self, x:TensorType, image_embeds:TensorType) -> TensorType:
135
+ # image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device)
136
+ # attn_mask = (x != self.config.pad_token_id).long()
137
+ # out = self.transformer(
138
+ # input_ids=x,
139
+ # attention_mask=attn_mask,
140
+ # encoder_hidden_states = image_embeds,
141
+ # encoder_attention_mask = image_atts,
142
+ # )
143
+ # pooled_out = self.pooler(out, attn_mask)
144
+
145
+ # return self.itm_proj(pooled_out)
146
+
147
+ def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
148
+ if masked_indices is None:
149
+ masked_indices = torch.bernoulli(probability_matrix).bool()
150
+
151
+ masked_indices[input_ids == self.tokenizer.pad_token_id] = False
152
+ masked_indices[input_ids == self.tokenizer.cls_token_id] = False
153
+
154
+ if targets is not None:
155
+ targets[~masked_indices] = -100 # We only compute loss on masked tokens
156
+
157
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
158
+ indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
159
+ input_ids[indices_replaced] = self.tokenizer.mask_token_id
160
+
161
+ # 10% of the time, we replace masked input tokens with random word
162
+ indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
163
+ random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
164
+ input_ids[indices_random] = random_words[indices_random]
165
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
166
+
167
+ if targets is not None:
168
+ return input_ids, targets
169
+ else:
170
+ return input_ids
171
+
172
+ def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
173
+ labels = input_ids.clone()
174
+ attn_mask = (input_ids != self.config.pad_token_id).long()
175
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(input_ids.device)
176
+ vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
177
+ probability_matrix = torch.full(labels.shape, mlm_probability)
178
+ input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels, probability_matrix=probability_matrix)
179
+ mlm_output = self.transformer(
180
+ input_ids,
181
+ attention_mask=attn_mask,
182
+ encoder_hidden_states=image_embeds,
183
+ encoder_attention_mask=image_atts,
184
+ return_dict=True,
185
+ labels=labels,
186
+ )
187
+ return mlm_output.loss
188
+ # mlm_output = self.transformer(input_ids,
189
+ # attention_mask = attn_mask,
190
+ # encoder_hidden_states = image_embeds,
191
+ # encoder_attention_mask = image_atts,
192
+ # return_dict = True,
193
+ # ).last_hidden_state
194
+ # logits = self.mlm_proj(mlm_output)
195
+
196
+ # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
197
+ # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)
198
+ # labels = labels[:, 1:].contiguous().view(-1)
199
+
200
+ # mlm_loss = F.cross_entropy(
201
+ # logits,
202
+ # labels,
203
+ # # label_smoothing=0.1,
204
+ # )
205
+ # return mlm_loss
206
+
207
+ def forward(self, x: TensorType) -> TensorType:
208
+ attn_mask = (x != self.config.pad_token_id).long()
209
+ out = self.transformer(input_ids=x, attention_mask=attn_mask)
210
+ pooled_out = self.pooler(out, attn_mask)
211
+
212
+ return self.proj(pooled_out)
213
+
214
+ def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
215
+ if not unlocked_layers: # full freezing
216
+ for n, p in self.transformer.named_parameters():
217
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
218
+ return
219
+
220
+ encoder = self.transformer.encoder if hasattr(self.transformer, "encoder") else self.transformer
221
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
222
+ print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
223
+ embeddings = getattr(self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
224
+ modules = [embeddings, *layer_list][:-unlocked_layers]
225
+ # freeze layers
226
+ for module in modules:
227
+ for n, p in module.named_parameters():
228
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
229
+
230
+ @torch.jit.ignore
231
+ def set_grad_checkpointing(self, enable=True):
232
+ self.transformer.gradient_checkpointing_enable()
233
+
234
+ def get_num_layers(self):
235
+ encoder = self.transformer.encoder if hasattr(self.transformer, "encoder") else self.transformer
236
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
237
+ return len(layer_list)
238
+
239
+ def init_parameters(self):
240
+ pass
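As a quick illustration of the pooler registry defined above, the snippet below instantiates the registered `mean_pooler` and runs it on a dummy encoder output. This is a minimal sketch, assuming `torch` and `transformers` are installed and the `blip3o` package is importable; the tensor sizes are arbitrary.

```python
# Sketch: using a registered pooler on dummy HF encoder output.
import torch
from transformers.modeling_outputs import BaseModelOutput

from blip3o.model.multimodal_encoder.dev_eva_clip.eva_clip.hf_model import _POOLERS

pooler = _POOLERS["mean_pooler"]()          # name comes from _camel2snake("MeanPooler")
hidden = BaseModelOutput(last_hidden_state=torch.randn(2, 7, 16))
mask = torch.ones(2, 7, dtype=torch.long)   # all tokens valid in this toy batch
pooled = pooler(hidden, mask)               # masked mean over the sequence dimension
print(pooled.shape)                         # torch.Size([2, 16])
```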
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py ADDED
@@ -0,0 +1,123 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+
6
+ try:
7
+ import torch.distributed.nn
8
+ from torch import distributed as dist
9
+
10
+ has_distributed = True
11
+ except ImportError:
12
+ has_distributed = False
13
+
14
+ try:
15
+ import horovod.torch as hvd
16
+ except ImportError:
17
+ hvd = None
18
+
19
+ from timm.loss import LabelSmoothingCrossEntropy
20
+
21
+
22
+ def gather_features(image_features, text_features, local_loss=False, gather_with_grad=False, rank=0, world_size=1, use_horovod=False):
23
+ assert has_distributed, "torch.distributed did not import correctly, please use a PyTorch version with support."
24
+ if use_horovod:
25
+ assert hvd is not None, "Please install horovod"
26
+ if gather_with_grad:
27
+ all_image_features = hvd.allgather(image_features)
28
+ all_text_features = hvd.allgather(text_features)
29
+ else:
30
+ with torch.no_grad():
31
+ all_image_features = hvd.allgather(image_features)
32
+ all_text_features = hvd.allgather(text_features)
33
+ if not local_loss:
34
+ # ensure grads for local rank when all_* features don't have a gradient
35
+ gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
36
+ gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
37
+ gathered_image_features[rank] = image_features
38
+ gathered_text_features[rank] = text_features
39
+ all_image_features = torch.cat(gathered_image_features, dim=0)
40
+ all_text_features = torch.cat(gathered_text_features, dim=0)
41
+ else:
42
+ # We gather tensors from all gpus
43
+ if gather_with_grad:
44
+ all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
45
+ all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
46
+ # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)
47
+ # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)
48
+ else:
49
+ gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
50
+ gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
51
+ dist.all_gather(gathered_image_features, image_features)
52
+ dist.all_gather(gathered_text_features, text_features)
53
+ if not local_loss:
54
+ # ensure grads for local rank when all_* features don't have a gradient
55
+ gathered_image_features[rank] = image_features
56
+ gathered_text_features[rank] = text_features
57
+ all_image_features = torch.cat(gathered_image_features, dim=0)
58
+ all_text_features = torch.cat(gathered_text_features, dim=0)
59
+
60
+ return all_image_features, all_text_features
61
+
62
+
63
+ class ClipLoss(nn.Module):
64
+
65
+ def __init__(
66
+ self,
67
+ local_loss=False,
68
+ gather_with_grad=False,
69
+ cache_labels=False,
70
+ rank=0,
71
+ world_size=1,
72
+ use_horovod=False,
73
+ smoothing=0.0,
74
+ ):
75
+ super().__init__()
76
+ self.local_loss = local_loss
77
+ self.gather_with_grad = gather_with_grad
78
+ self.cache_labels = cache_labels
79
+ self.rank = rank
80
+ self.world_size = world_size
81
+ self.use_horovod = use_horovod
82
+ self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None
83
+
84
+ # cache state
85
+ self.prev_num_logits = 0
86
+ self.labels = {}
87
+
88
+ def forward(self, image_features, text_features, logit_scale=1.0):
89
+ device = image_features.device
90
+ if self.world_size > 1:
91
+ all_image_features, all_text_features = gather_features(image_features, text_features, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
92
+
93
+ if self.local_loss:
94
+ logits_per_image = logit_scale * image_features @ all_text_features.T
95
+ logits_per_text = logit_scale * text_features @ all_image_features.T
96
+ else:
97
+ logits_per_image = logit_scale * all_image_features @ all_text_features.T
98
+ logits_per_text = logits_per_image.T
99
+ else:
100
+ logits_per_image = logit_scale * image_features @ text_features.T
101
+ logits_per_text = logit_scale * text_features @ image_features.T
102
+ # calculated ground-truth and cache if enabled
103
+ num_logits = logits_per_image.shape[0]
104
+ if self.prev_num_logits != num_logits or device not in self.labels:
105
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
106
+ if self.world_size > 1 and self.local_loss:
107
+ labels = labels + num_logits * self.rank
108
+ if self.cache_labels:
109
+ self.labels[device] = labels
110
+ self.prev_num_logits = num_logits
111
+ else:
112
+ labels = self.labels[device]
113
+
114
+ if self.label_smoothing_cross_entropy:
115
+ total_loss = (self.label_smoothing_cross_entropy(logits_per_image, labels) + self.label_smoothing_cross_entropy(logits_per_text, labels)) / 2
116
+ else:
117
+ total_loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2
118
+
119
+ acc = None
120
+ i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
121
+ t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
122
+ acc = {"i2t": i2t_acc, "t2i": t2i_acc}
123
+ return total_loss, acc
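`ClipLoss` reduces to plain symmetric InfoNCE when `world_size == 1`, which makes it easy to smoke-test without any distributed setup. A minimal single-process sketch, assuming `torch` and `timm` are installed and the `blip3o` package is importable (batch size and feature width are arbitrary):

```python
# Sketch: single-process ClipLoss with random, L2-normalized features.
import torch
import torch.nn.functional as F

from blip3o.model.multimodal_encoder.dev_eva_clip.eva_clip.loss import ClipLoss

loss_fn = ClipLoss()  # defaults: world_size=1, no label smoothing
image_features = F.normalize(torch.randn(4, 8), dim=-1)
text_features = F.normalize(torch.randn(4, 8), dim=-1)

loss, acc = loss_fn(image_features, text_features, logit_scale=100.0)
print(loss.item(), acc["i2t"].item(), acc["t2i"].item())
```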
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/model.py ADDED
@@ -0,0 +1,429 @@
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+
6
+ import os
7
+ from dataclasses import dataclass
8
+ from typing import Optional, Tuple, Union
9
+ from functools import partial
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import nn
15
+
16
+ try:
17
+ from .hf_model import HFTextEncoder
18
+ except ImportError:
19
+ HFTextEncoder = None
20
+ from .modified_resnet import ModifiedResNet
21
+ from .timm_model import TimmModel
22
+ from .eva_vit_model import EVAVisionTransformer
23
+ from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
24
+
25
+ try:
26
+ from apex.normalization import FusedLayerNorm
27
+ except ImportError:
28
+ FusedLayerNorm = LayerNorm
29
+ # print("Please 'pip install apex'")
30
+
31
+ try:
32
+ import xformers.ops as xops
33
+ except ImportError:
34
+ xops = None
35
+ # print("Please 'pip install xformers'")
36
+
37
+
38
+ class RMSnorm(nn.Module):
39
+ """
40
+ adapted from transformers T5LayerNorm
41
+ """
42
+
43
+ def __init__(self, hidden_size, eps=1e-6):
44
+ """
45
+ Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
46
+ """
47
+ super().__init__()
48
+ self.weight = nn.Parameter(torch.ones(hidden_size))
49
+ self.variance_epsilon = eps
50
+
51
+ def forward(self, hidden_states):
52
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
53
+ # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
54
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
55
+ # half-precision inputs is done in fp32
56
+
57
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
58
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
59
+
60
+ # convert into half-precision if necessary
61
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
62
+ hidden_states = hidden_states.to(self.weight.dtype)
63
+
64
+ return self.weight * hidden_states
65
+
66
+
67
+ @dataclass
68
+ class CLIPVisionCfg:
69
+ layers: Union[Tuple[int, int, int, int], int] = 12
70
+ width: int = 768
71
+ head_width: int = 64
72
+ mlp_ratio: float = 4.0
73
+ patch_size: int = 16
74
+ image_size: Union[Tuple[int, int], int] = 224
75
+ ls_init_value: Optional[float] = None # layer scale initial value
76
+ patch_dropout: float = 0.0 # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
77
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
78
+ drop_path_rate: Optional[float] = None # drop path rate
79
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
80
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
81
+ timm_pool: str = "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
82
+ timm_proj: str = "linear" # linear projection for timm model output ('linear', 'mlp', '')
83
+ timm_proj_bias: bool = False # enable bias final projection
84
+ eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
85
+ qkv_bias: bool = True
86
+ fusedLN: bool = False
87
+ xattn: bool = False
88
+ postnorm: bool = False
89
+ rope: bool = False
90
+ pt_hw_seq_len: int = 16 # 224/14
91
+ intp_freq: bool = False
92
+ naiveswiglu: bool = False
93
+ subln: bool = False
94
+ use_rms_norm: bool = False
95
+
96
+
97
+ @dataclass
98
+ class CLIPTextCfg:
99
+ context_length: int = 77
100
+ vocab_size: int = 49408
101
+ width: int = 512
102
+ heads: int = 8
103
+ layers: int = 12
104
+ ls_init_value: Optional[float] = None # layer scale initial value
105
+ hf_model_name: str = None
106
+ hf_tokenizer_name: str = None
107
+ hf_model_pretrained: bool = True
108
+ proj: str = "mlp"
109
+ pooler_type: str = "mean_pooler"
110
+ masked_language_modeling: bool = False
111
+ fusedLN: bool = False
112
+ xattn: bool = False
113
+ attn_mask: bool = True
114
+
115
+
116
+ def get_cast_dtype(precision: str):
117
+ cast_dtype = None
118
+ if precision == "bf16":
119
+ cast_dtype = torch.bfloat16
120
+ elif precision == "fp16":
121
+ cast_dtype = torch.float16
122
+ return cast_dtype
123
+
124
+
125
+ def _build_vision_tower(embed_dim: int, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None):
126
+ if isinstance(vision_cfg, dict):
127
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
128
+
129
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
130
+ # memory efficient in recent PyTorch releases (>= 1.10).
131
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
132
+ act_layer = QuickGELU if quick_gelu else nn.GELU
133
+
134
+ if vision_cfg.eva_model_name:
135
+ vision_heads = vision_cfg.width // vision_cfg.head_width
136
+
137
+ norm_layer = RMSnorm if vision_cfg.use_rms_norm else LayerNorm
138
+
139
+ visual = EVAVisionTransformer(
140
+ img_size=vision_cfg.image_size,
141
+ patch_size=vision_cfg.patch_size,
142
+ num_classes=embed_dim,
143
+ use_mean_pooling=vision_cfg.global_average_pool, # False
144
+ init_values=vision_cfg.ls_init_value,
145
+ patch_dropout=vision_cfg.patch_dropout,
146
+ embed_dim=vision_cfg.width,
147
+ depth=vision_cfg.layers,
148
+ num_heads=vision_heads,
149
+ mlp_ratio=vision_cfg.mlp_ratio,
150
+ qkv_bias=vision_cfg.qkv_bias,
151
+ drop_path_rate=vision_cfg.drop_path_rate,
152
+ norm_layer=partial(norm_layer, eps=1e-6),
153
+ xattn=vision_cfg.xattn,
154
+ rope=vision_cfg.rope,
155
+ postnorm=vision_cfg.postnorm,
156
+ pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14
157
+ intp_freq=vision_cfg.intp_freq,
158
+ naiveswiglu=vision_cfg.naiveswiglu,
159
+ subln=vision_cfg.subln,
160
+ )
161
+ elif vision_cfg.timm_model_name:
162
+ visual = TimmModel(
163
+ vision_cfg.timm_model_name, pretrained=vision_cfg.timm_model_pretrained, pool=vision_cfg.timm_pool, proj=vision_cfg.timm_proj, proj_bias=vision_cfg.timm_proj_bias, embed_dim=embed_dim, image_size=vision_cfg.image_size
164
+ )
165
+ act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
166
+ elif isinstance(vision_cfg.layers, (tuple, list)):
167
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
168
+ visual = ModifiedResNet(layers=vision_cfg.layers, output_dim=embed_dim, heads=vision_heads, image_size=vision_cfg.image_size, width=vision_cfg.width)
169
+ else:
170
+ vision_heads = vision_cfg.width // vision_cfg.head_width
171
+ norm_layer = LayerNorm  # NOTE: upstream open_clip uses LayerNormFp32 for fp16/bf16 cast dtypes, but LayerNormFp32 is not imported in this file
172
+ visual = VisionTransformer(
173
+ image_size=vision_cfg.image_size,
174
+ patch_size=vision_cfg.patch_size,
175
+ width=vision_cfg.width,
176
+ layers=vision_cfg.layers,
177
+ heads=vision_heads,
178
+ mlp_ratio=vision_cfg.mlp_ratio,
179
+ ls_init_value=vision_cfg.ls_init_value,
180
+ patch_dropout=vision_cfg.patch_dropout,
181
+ global_average_pool=vision_cfg.global_average_pool,
182
+ output_dim=embed_dim,
183
+ act_layer=act_layer,
184
+ norm_layer=norm_layer,
185
+ )
186
+
187
+ return visual
188
+
189
+
190
+ def _build_text_tower(
191
+ embed_dim: int,
192
+ text_cfg: CLIPTextCfg,
193
+ quick_gelu: bool = False,
194
+ cast_dtype: Optional[torch.dtype] = None,
195
+ ):
196
+ if isinstance(text_cfg, dict):
197
+ text_cfg = CLIPTextCfg(**text_cfg)
198
+
199
+ if text_cfg.hf_model_name:
200
+ text = HFTextEncoder(text_cfg.hf_model_name, output_dim=embed_dim, tokenizer_name=text_cfg.hf_tokenizer_name, proj=text_cfg.proj, pooler_type=text_cfg.pooler_type, masked_language_modeling=text_cfg.masked_language_modeling)
201
+ else:
202
+ act_layer = QuickGELU if quick_gelu else nn.GELU
203
+ norm_layer = LayerNorm
204
+
205
+ text = TextTransformer(
206
+ context_length=text_cfg.context_length,
207
+ vocab_size=text_cfg.vocab_size,
208
+ width=text_cfg.width,
209
+ heads=text_cfg.heads,
210
+ layers=text_cfg.layers,
211
+ ls_init_value=text_cfg.ls_init_value,
212
+ output_dim=embed_dim,
213
+ act_layer=act_layer,
214
+ norm_layer=FusedLayerNorm if text_cfg.fusedLN else norm_layer,
215
+ xattn=text_cfg.xattn,
216
+ attn_mask=text_cfg.attn_mask,
217
+ )
218
+ return text
219
+
220
+
221
+ class CLIP(nn.Module):
222
+ def __init__(
223
+ self,
224
+ embed_dim: int,
225
+ vision_cfg: CLIPVisionCfg,
226
+ text_cfg: CLIPTextCfg,
227
+ quick_gelu: bool = False,
228
+ cast_dtype: Optional[torch.dtype] = None,
229
+ ):
230
+ super().__init__()
231
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
232
+
233
+ text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
234
+ self.transformer = text.transformer
235
+ self.vocab_size = text.vocab_size
236
+ self.token_embedding = text.token_embedding
237
+ self.positional_embedding = text.positional_embedding
238
+ self.ln_final = text.ln_final
239
+ self.text_projection = text.text_projection
240
+ self.register_buffer("attn_mask", text.attn_mask, persistent=False)
241
+
242
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
243
+
244
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
245
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
246
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
247
+
248
+ @torch.jit.ignore
249
+ def set_grad_checkpointing(self, enable=True):
250
+ self.visual.set_grad_checkpointing(enable)
251
+ self.transformer.grad_checkpointing = enable
252
+
253
+ @torch.jit.ignore
254
+ def no_weight_decay(self):
255
+ return {"logit_scale"}
256
+
257
+ def encode_image(self, image, normalize: bool = False):
258
+ features = self.visual(image)
259
+ return F.normalize(features, dim=-1) if normalize else features
260
+
261
+ def encode_text(self, text, normalize: bool = False):
262
+ cast_dtype = self.transformer.get_cast_dtype()
263
+
264
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
265
+
266
+ x = x + self.positional_embedding.to(cast_dtype)
267
+ x = x.permute(1, 0, 2) # NLD -> LND
268
+ x = self.transformer(x, attn_mask=self.attn_mask)
269
+ x = x.permute(1, 0, 2) # LND -> NLD
270
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
271
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
272
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
273
+ return F.normalize(x, dim=-1) if normalize else x
274
+
275
+ def forward(self, image, text):
276
+ image_features = self.encode_image(image, normalize=True)
277
+ text_features = self.encode_text(text, normalize=True)
278
+ return image_features, text_features, self.logit_scale.exp()
279
+
280
+
281
+ class CustomCLIP(nn.Module):
282
+ def __init__(
283
+ self,
284
+ embed_dim: int,
285
+ vision_cfg: CLIPVisionCfg,
286
+ text_cfg: CLIPTextCfg,
287
+ quick_gelu: bool = False,
288
+ cast_dtype: Optional[torch.dtype] = None,
289
+ itm_task: bool = False,
290
+ ):
291
+ super().__init__()
292
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
293
+ self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
294
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
295
+
296
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
297
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
298
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
299
+
300
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
301
+ self.text.lock(unlocked_layers, freeze_layer_norm)
302
+
303
+ @torch.jit.ignore
304
+ def set_grad_checkpointing(self, enable=True):
305
+ self.visual.set_grad_checkpointing(enable)
306
+ self.text.set_grad_checkpointing(enable)
307
+
308
+ @torch.jit.ignore
309
+ def no_weight_decay(self):
310
+ return {"logit_scale"}
311
+
312
+ def encode_image(self, image, normalize: bool = False):
313
+ features = self.visual(image)
314
+ return F.normalize(features, dim=-1) if normalize else features
315
+
316
+ def encode_text(self, text, normalize: bool = False):
317
+ features = self.text(text)
318
+ return F.normalize(features, dim=-1) if normalize else features
319
+
320
+ def forward(self, image, text):
321
+ image_features = self.encode_image(image, normalize=True)
322
+ text_features = self.encode_text(text, normalize=True)
323
+ return image_features, text_features, self.logit_scale.exp()
324
+
325
+
326
+ def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
327
+ """Convert applicable model parameters to low-precision (bf16 or fp16)"""
328
+
329
+ def _convert_weights(l):
330
+
331
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
332
+ l.weight.data = l.weight.data.to(dtype)
333
+ if l.bias is not None:
334
+ l.bias.data = l.bias.data.to(dtype)
335
+
336
+ if isinstance(l, (nn.MultiheadAttention, Attention)):
337
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
338
+ tensor = getattr(l, attr, None)
339
+ if tensor is not None:
340
+ tensor.data = tensor.data.to(dtype)
341
+
342
+ if isinstance(l, nn.Parameter):
343
+ l.data = l.data.to(dtype)
344
+
345
+ for name in ["text_projection", "proj"]:
346
+ if hasattr(l, name) and isinstance(l, nn.Parameter):
347
+ attr = getattr(l, name, None)
348
+ if attr is not None:
349
+ attr.data = attr.data.to(dtype)
350
+
351
+ model.apply(_convert_weights)
352
+
353
+
354
+ convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
355
+
356
+
357
+ # used to maintain checkpoint compatibility
358
+ def convert_to_custom_text_state_dict(state_dict: dict):
359
+ if "text_projection" in state_dict:
360
+ # old format state_dict, move text tower -> .text
361
+ new_state_dict = {}
362
+ for k, v in state_dict.items():
363
+ if any(k.startswith(p) for p in ("text_projection", "positional_embedding", "token_embedding", "transformer", "ln_final", "logit_scale")):
364
+ k = "text." + k
365
+ new_state_dict[k] = v
366
+ return new_state_dict
367
+ return state_dict
368
+
369
+
370
+ def build_model_from_openai_state_dict(
371
+ state_dict: dict,
372
+ quick_gelu=True,
373
+ cast_dtype=torch.float16,
374
+ ):
375
+ vit = "visual.proj" in state_dict
376
+
377
+ if vit:
378
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
379
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
380
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
381
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
382
+ image_size = vision_patch_size * grid_size
383
+ else:
384
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
385
+ vision_layers = tuple(counts)
386
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
387
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
388
+ vision_patch_size = None
389
+ assert output_width**2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
390
+ image_size = output_width * 32
391
+
392
+ embed_dim = state_dict["text_projection"].shape[1]
393
+ context_length = state_dict["positional_embedding"].shape[0]
394
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
395
+ transformer_width = state_dict["ln_final.weight"].shape[0]
396
+ transformer_heads = transformer_width // 64
397
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
398
+
399
+ vision_cfg = CLIPVisionCfg(
400
+ layers=vision_layers,
401
+ width=vision_width,
402
+ patch_size=vision_patch_size,
403
+ image_size=image_size,
404
+ )
405
+ text_cfg = CLIPTextCfg(context_length=context_length, vocab_size=vocab_size, width=transformer_width, heads=transformer_heads, layers=transformer_layers)
406
+ model = CLIP(
407
+ embed_dim,
408
+ vision_cfg=vision_cfg,
409
+ text_cfg=text_cfg,
410
+ quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
411
+ cast_dtype=cast_dtype,
412
+ )
413
+
414
+ for key in ["input_resolution", "context_length", "vocab_size"]:
415
+ state_dict.pop(key, None)
416
+
417
+ convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
418
+ model.load_state_dict(state_dict)
419
+ return model.eval()
420
+
421
+
422
+ def trace_model(model, batch_size=256, device=torch.device("cpu")):
423
+ model.eval()
424
+ image_size = model.visual.image_size
425
+ example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
426
+ example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
427
+ model = torch.jit.trace_module(model, inputs=dict(forward=(example_images, example_text), encode_text=(example_text,), encode_image=(example_images,)))
428
+ model.visual.image_size = image_size
429
+ return model
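The two config dataclasses above are what `factory.py` fills from the JSON files under `model_configs/`, and `CLIP` accepts plain dicts as well. The sketch below builds a deliberately tiny model to show the flow; the sizes are purely illustrative, and it assumes the bundled `transformer.py` follows the usual open_clip `VisionTransformer`/`TextTransformer` interfaces:

```python
# Sketch: building a tiny CLIP directly from config dicts (illustrative sizes).
import torch

from blip3o.model.multimodal_encoder.dev_eva_clip.eva_clip.model import CLIP

vision_cfg = dict(image_size=32, patch_size=8, width=64, layers=2, head_width=32)
text_cfg = dict(context_length=16, vocab_size=1000, width=64, heads=4, layers=2)

model = CLIP(embed_dim=64, vision_cfg=vision_cfg, text_cfg=text_cfg)
image = torch.randn(2, 3, 32, 32)
text = torch.randint(0, 1000, (2, 16))

image_features, text_features, logit_scale = model(image, text)  # both L2-normalized
print(image_features.shape, text_features.shape, logit_scale.item())
```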
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py ADDED
@@ -0,0 +1,179 @@
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from .utils import freeze_batch_norm_2d
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+ self.act1 = nn.ReLU(inplace=True)
20
+
21
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22
+ self.bn2 = nn.BatchNorm2d(planes)
23
+ self.act2 = nn.ReLU(inplace=True)
24
+
25
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
+
27
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
+ self.act3 = nn.ReLU(inplace=True)
30
+
31
+ self.downsample = None
32
+ self.stride = stride
33
+
34
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
+ self.downsample = nn.Sequential(OrderedDict([("-1", nn.AvgPool2d(stride)), ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), ("1", nn.BatchNorm2d(planes * self.expansion))]))
37
+
38
+ def forward(self, x: torch.Tensor):
39
+ identity = x
40
+
41
+ out = self.act1(self.bn1(self.conv1(x)))
42
+ out = self.act2(self.bn2(self.conv2(out)))
43
+ out = self.avgpool(out)
44
+ out = self.bn3(self.conv3(out))
45
+
46
+ if self.downsample is not None:
47
+ identity = self.downsample(x)
48
+
49
+ out += identity
50
+ out = self.act3(out)
51
+ return out
52
+
53
+
54
+ class AttentionPool2d(nn.Module):
55
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
56
+ super().__init__()
57
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
58
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
59
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
60
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
61
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
62
+ self.num_heads = num_heads
63
+
64
+ def forward(self, x):
65
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
66
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
67
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
68
+ x, _ = F.multi_head_attention_forward(
69
+ query=x,
70
+ key=x,
71
+ value=x,
72
+ embed_dim_to_check=x.shape[-1],
73
+ num_heads=self.num_heads,
74
+ q_proj_weight=self.q_proj.weight,
75
+ k_proj_weight=self.k_proj.weight,
76
+ v_proj_weight=self.v_proj.weight,
77
+ in_proj_weight=None,
78
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
79
+ bias_k=None,
80
+ bias_v=None,
81
+ add_zero_attn=False,
82
+ dropout_p=0.0,
83
+ out_proj_weight=self.c_proj.weight,
84
+ out_proj_bias=self.c_proj.bias,
85
+ use_separate_proj_weight=True,
86
+ training=self.training,
87
+ need_weights=False,
88
+ )
89
+
90
+ return x[0]
91
+
92
+
93
+ class ModifiedResNet(nn.Module):
94
+ """
95
+ A ResNet class that is similar to torchvision's but contains the following changes:
96
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
97
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
98
+ - The final pooling layer is a QKV attention instead of an average pool
99
+ """
100
+
101
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
102
+ super().__init__()
103
+ self.output_dim = output_dim
104
+ self.image_size = image_size
105
+
106
+ # the 3-layer stem
107
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
108
+ self.bn1 = nn.BatchNorm2d(width // 2)
109
+ self.act1 = nn.ReLU(inplace=True)
110
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
111
+ self.bn2 = nn.BatchNorm2d(width // 2)
112
+ self.act2 = nn.ReLU(inplace=True)
113
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
114
+ self.bn3 = nn.BatchNorm2d(width)
115
+ self.act3 = nn.ReLU(inplace=True)
116
+ self.avgpool = nn.AvgPool2d(2)
117
+
118
+ # residual layers
119
+ self._inplanes = width # this is a *mutable* variable used during construction
120
+ self.layer1 = self._make_layer(width, layers[0])
121
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
122
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
123
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
124
+
125
+ embed_dim = width * 32 # the ResNet feature dimension
126
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
127
+
128
+ self.init_parameters()
129
+
130
+ def _make_layer(self, planes, blocks, stride=1):
131
+ layers = [Bottleneck(self._inplanes, planes, stride)]
132
+
133
+ self._inplanes = planes * Bottleneck.expansion
134
+ for _ in range(1, blocks):
135
+ layers.append(Bottleneck(self._inplanes, planes))
136
+
137
+ return nn.Sequential(*layers)
138
+
139
+ def init_parameters(self):
140
+ if self.attnpool is not None:
141
+ std = self.attnpool.c_proj.in_features**-0.5
142
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
143
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
144
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
145
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
146
+
147
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
148
+ for name, param in resnet_block.named_parameters():
149
+ if name.endswith("bn3.weight"):
150
+ nn.init.zeros_(param)
151
+
152
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
153
+ assert unlocked_groups == 0, "partial locking not currently supported for this model"
154
+ for param in self.parameters():
155
+ param.requires_grad = False
156
+ if freeze_bn_stats:
157
+ freeze_batch_norm_2d(self)
158
+
159
+ @torch.jit.ignore
160
+ def set_grad_checkpointing(self, enable=True):
161
+ # FIXME support for non-transformer
162
+ pass
163
+
164
+ def stem(self, x):
165
+ x = self.act1(self.bn1(self.conv1(x)))
166
+ x = self.act2(self.bn2(self.conv2(x)))
167
+ x = self.act3(self.bn3(self.conv3(x)))
168
+ x = self.avgpool(x)
169
+ return x
170
+
171
+ def forward(self, x):
172
+ x = self.stem(x)
173
+ x = self.layer1(x)
174
+ x = self.layer2(x)
175
+ x = self.layer3(x)
176
+ x = self.layer4(x)
177
+ x = self.attnpool(x)
178
+
179
+ return x
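For reference, the tower can be exercised on its own. A minimal sketch with illustrative (non-RN50) sizes, assuming `torch` is installed and the `blip3o` package is importable:

```python
# Sketch: forward pass through the modified (anti-aliased, attention-pooled) ResNet.
import torch

from blip3o.model.multimodal_encoder.dev_eva_clip.eva_clip.modified_resnet import ModifiedResNet

# OpenAI's RN50 would be layers=(3, 4, 6, 3), width=64; these sizes are just a quick test.
model = ModifiedResNet(layers=(2, 2, 2, 2), output_dim=128, heads=8, image_size=224, width=32)
features = model(torch.randn(1, 3, 224, 224))  # stem -> 4 stages -> AttentionPool2d
print(features.shape)                          # torch.Size([1, 128])
```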
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py ADDED
@@ -0,0 +1,144 @@
1
+ """ OpenAI pretrained model functions
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+
6
+ import os
7
+ import warnings
8
+ from typing import List, Optional, Union
9
+
10
+ import torch
11
+
12
+ from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
13
+ from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
14
+
15
+ __all__ = ["list_openai_models", "load_openai_model"]
16
+
17
+
18
+ def list_openai_models() -> List[str]:
19
+ """Returns the names of available CLIP models"""
20
+ return list_pretrained_models_by_tag("openai")
21
+
22
+
23
+ def load_openai_model(
24
+ name: str,
25
+ precision: Optional[str] = None,
26
+ device: Optional[Union[str, torch.device]] = None,
27
+ jit: bool = True,
28
+ cache_dir: Optional[str] = None,
29
+ ):
30
+ """Load a CLIP model
31
+
32
+ Parameters
33
+ ----------
34
+ name : str
35
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
36
+ precision: str
37
+ Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38
+ device : Union[str, torch.device]
39
+ The device to put the loaded model
40
+ jit : bool
41
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
42
+ cache_dir : Optional[str]
43
+ The directory to cache the downloaded model weights
44
+
45
+ Returns
46
+ -------
47
+ model : torch.nn.Module
48
+ The CLIP model
51
+ """
52
+ if device is None:
53
+ device = "cuda" if torch.cuda.is_available() else "cpu"
54
+ if precision is None:
55
+ precision = "fp32" if device == "cpu" else "fp16"
56
+
57
+ if get_pretrained_url(name, "openai"):
58
+ model_path = download_pretrained_from_url(get_pretrained_url(name, "openai"), cache_dir=cache_dir)
59
+ elif os.path.isfile(name):
60
+ model_path = name
61
+ else:
62
+ raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
63
+
64
+ try:
65
+ # loading JIT archive
66
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
67
+ state_dict = None
68
+ except RuntimeError:
69
+ # loading saved state dict
70
+ if jit:
71
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
72
+ jit = False
73
+ state_dict = torch.load(model_path, map_location="cpu")
74
+
75
+ if not jit:
76
+ # Build a non-jit model from the OpenAI jitted model state dict
77
+ cast_dtype = get_cast_dtype(precision)
78
+ try:
79
+ model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
80
+ except KeyError:
81
+ sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
82
+ model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
83
+
84
+ # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
85
+ model = model.to(device)
86
+ if precision.startswith("amp") or precision == "fp32":
87
+ model.float()
88
+ elif precision == "bf16":
89
+ convert_weights_to_lp(model, dtype=torch.bfloat16)
90
+
91
+ return model
92
+
93
+ # patch the device names
94
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
95
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
96
+
97
+ def patch_device(module):
98
+ try:
99
+ graphs = [module.graph] if hasattr(module, "graph") else []
100
+ except RuntimeError:
101
+ graphs = []
102
+
103
+ if hasattr(module, "forward1"):
104
+ graphs.append(module.forward1.graph)
105
+
106
+ for graph in graphs:
107
+ for node in graph.findAllNodes("prim::Constant"):
108
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
109
+ node.copyAttributes(device_node)
110
+
111
+ model.apply(patch_device)
112
+ patch_device(model.encode_image)
113
+ patch_device(model.encode_text)
114
+
115
+ # patch dtype to float32 (typically for CPU)
116
+ if precision == "fp32":
117
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
118
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
119
+ float_node = float_input.node()
120
+
121
+ def patch_float(module):
122
+ try:
123
+ graphs = [module.graph] if hasattr(module, "graph") else []
124
+ except RuntimeError:
125
+ graphs = []
126
+
127
+ if hasattr(module, "forward1"):
128
+ graphs.append(module.forward1.graph)
129
+
130
+ for graph in graphs:
131
+ for node in graph.findAllNodes("aten::to"):
132
+ inputs = list(node.inputs())
133
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
134
+ if inputs[i].node()["value"] == 5:
135
+ inputs[i].node().copyAttributes(float_node)
136
+
137
+ model.apply(patch_float)
138
+ patch_float(model.encode_image)
139
+ patch_float(model.encode_text)
140
+ model.float()
141
+
142
+ # ensure image_size attr available at consistent location for both jit and non-jit
143
+ model.visual.image_size = model.input_resolution.item()
144
+ return model
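A short usage sketch for the loader above: listing the registry is free, while the guarded branch would actually download OpenAI weights (several hundred MB), so it is off by default. Assumes the `blip3o` package is importable.

```python
# Sketch: enumerate OpenAI-tagged checkpoints; optionally load one on CPU.
from blip3o.model.multimodal_encoder.dev_eva_clip.eva_clip.openai import (
    list_openai_models,
    load_openai_model,
)

print(list_openai_models())  # e.g. OpenaiCLIP-B-32, OpenaiCLIP-B-16, OpenaiCLIP-L-14, ...

DOWNLOAD = False  # set True to actually fetch the checkpoint
if DOWNLOAD:
    model = load_openai_model("OpenaiCLIP-B-32", precision="fp32", device="cpu", jit=False)
    print(model.visual.image_size)
```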
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/pretrained.py ADDED
@@ -0,0 +1,314 @@
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Dict, Union
6
+
7
+ from tqdm import tqdm
8
+
9
+ try:
10
+ from huggingface_hub import hf_hub_download
11
+
12
+ _has_hf_hub = True
13
+ except ImportError:
14
+ hf_hub_download = None
15
+ _has_hf_hub = False
16
+
17
+
18
+ def _pcfg(url="", hf_hub="", filename="", mean=None, std=None):
19
+ return dict(
20
+ url=url,
21
+ hf_hub=hf_hub,
22
+ mean=mean,
23
+ std=std,
24
+ )
25
+
26
+
27
+ _VITB32 = dict(
28
+ openai=_pcfg("https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
29
+ laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
30
+ laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
31
+ laion2b_e16=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
32
+ laion2b_s34b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-B-32-laion2B-s34B-b79K/"),
33
+ )
34
+
35
+ _VITB32_quickgelu = dict(
36
+ openai=_pcfg("https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
37
+ laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
38
+ laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
39
+ )
40
+
41
+ _VITB16 = dict(
42
+ openai=_pcfg("https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
43
+ laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
44
+ laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
45
+ laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-B-16-laion2B-s34B-b88K/"),
46
+ )
47
+
48
+ _EVAB16 = dict(
49
+ eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt"),
50
+ eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt"),
51
+ eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt"),
52
+ eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt"),
53
+ )
54
+
55
+ _VITB16_PLUS_240 = dict(
56
+ laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
57
+ laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
58
+ )
59
+
60
+ _VITL14 = dict(
61
+ openai=_pcfg("https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
62
+ laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
63
+ laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
64
+ laion2b_s32b_b82k=_pcfg(hf_hub="laion/CLIP-ViT-L-14-laion2B-s32B-b82K/", mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
65
+ )
66
+
67
+ _EVAL14 = dict(
68
+ eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_L_psz14.pt"),
69
+ eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_L_psz14.pt"),
70
+ eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt"),
71
+ eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt"),
72
+ )
73
+
74
+ _VITL14_336 = dict(
75
+ openai=_pcfg("https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
76
+ )
77
+
78
+ _EVAL14_336 = dict(
79
+ eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt"),
80
+ eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt"),
81
+ eva_clip_224to336=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt"),
82
+ eva02_clip_224to336=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt"),
83
+ )
84
+
85
+ _VITH14 = dict(
86
+ laion2b_s32b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-H-14-laion2B-s32B-b79K/"),
87
+ )
88
+
89
+ _VITg14 = dict(
90
+ laion2b_s12b_b42k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s12B-b42K/"),
91
+ laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s34B-b88K/"),
92
+ )
93
+
94
+ _EVAg14 = dict(
95
+ eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/"),
96
+ eva01=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_g_psz14.pt"),
97
+ eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt"),
98
+ eva01_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt"),
99
+ )
100
+
101
+ _EVAg14_PLUS = dict(
102
+ eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/"),
103
+ eva01=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_g_psz14.pt"),
104
+ eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt"),
105
+ eva01_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt"),
106
+ )
107
+
108
+ _VITbigG14 = dict(
109
+ laion2b_s39b_b160k=_pcfg(hf_hub="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/"),
110
+ )
111
+
112
+ _EVAbigE14 = dict(
113
+ eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
114
+ eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
115
+ eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt"),
116
+ eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt"),
117
+ )
118
+
119
+ _EVAbigE14_PLUS = dict(
120
+ eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
121
+ eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
122
+ eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt"),
123
+ eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt"),
124
+ )
125
+
126
+ _EVA_8B = dict(
127
+ eva=_pcfg(hf_hub="BAAI/EVA-CLIP-8B/EVA_8B_psz14.bin"),
128
+ eva_clip=_pcfg(hf_hub="BAAI/EVA-CLIP-8B/EVA_CLIP_8B_psz14_s9B.pt"),
129
+ )
130
+
131
+ _EVA_8B_PLUS = dict(
132
+ eva_clip=_pcfg(hf_hub="BAAI/EVA-CLIP-8B-448/EVA_CLIP_8B_psz14_plus_s0.6B.pt"),
133
+ )
134
+
135
+
136
+ _PRETRAINED = {
137
+ # "ViT-B-32": _VITB32,
138
+ "OpenaiCLIP-B-32": _VITB32,
139
+ "OpenCLIP-B-32": _VITB32,
140
+ # "ViT-B-32-quickgelu": _VITB32_quickgelu,
141
+ "OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu,
142
+ "OpenCLIP-B-32-quickgelu": _VITB32_quickgelu,
143
+ # "ViT-B-16": _VITB16,
144
+ "OpenaiCLIP-B-16": _VITB16,
145
+ "OpenCLIP-B-16": _VITB16,
146
+ "EVA02-B-16": _EVAB16,
147
+ "EVA02-CLIP-B-16": _EVAB16,
148
+ # "ViT-B-16-plus-240": _VITB16_PLUS_240,
149
+ "OpenCLIP-B-16-plus-240": _VITB16_PLUS_240,
150
+ # "ViT-L-14": _VITL14,
151
+ "OpenaiCLIP-L-14": _VITL14,
152
+ "OpenCLIP-L-14": _VITL14,
153
+ "EVA02-L-14": _EVAL14,
154
+ "EVA02-CLIP-L-14": _EVAL14,
155
+ # "ViT-L-14-336": _VITL14_336,
156
+ "OpenaiCLIP-L-14-336": _VITL14_336,
157
+ "EVA02-CLIP-L-14-336": _EVAL14_336,
158
+ # "ViT-H-14": _VITH14,
159
+ # "ViT-g-14": _VITg14,
160
+ "OpenCLIP-H-14": _VITH14,
161
+ "OpenCLIP-g-14": _VITg14,
162
+ "EVA01-CLIP-g-14": _EVAg14,
163
+ "EVA01-CLIP-g-14-plus": _EVAg14_PLUS,
164
+ # "ViT-bigG-14": _VITbigG14,
165
+ "OpenCLIP-bigG-14": _VITbigG14,
166
+ "EVA02-CLIP-bigE-14": _EVAbigE14,
167
+ "EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS,
168
+ "EVA-CLIP-8B": _EVA_8B,
169
+ "EVA-CLIP-8B-448": _EVA_8B_PLUS,
170
+ "EVA-CLIP-8B-plus": _EVA_8B_PLUS,
171
+ }
172
+
173
+
174
+ def _clean_tag(tag: str):
175
+ # normalize pretrained tags
176
+ return tag.lower().replace("-", "_")
177
+
178
+
179
+ def list_pretrained(as_str: bool = False):
180
+ """returns list of pretrained models
181
+ Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
182
+ """
183
+ return [":".join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
184
+
185
+
186
+ def list_pretrained_models_by_tag(tag: str):
187
+ """return all models having the specified pretrain tag"""
188
+ models = []
189
+ tag = _clean_tag(tag)
190
+ for k in _PRETRAINED.keys():
191
+ if tag in _PRETRAINED[k]:
192
+ models.append(k)
193
+ return models
194
+
195
+
196
+ def list_pretrained_tags_by_model(model: str):
197
+ """return all pretrain tags for the specified model architecture"""
198
+ tags = []
199
+ if model in _PRETRAINED:
200
+ tags.extend(_PRETRAINED[model].keys())
201
+ return tags
202
+
203
+
204
+ def is_pretrained_cfg(model: str, tag: str):
205
+ if model not in _PRETRAINED:
206
+ return False
207
+ return _clean_tag(tag) in _PRETRAINED[model]
208
+
209
+
210
+ def get_pretrained_cfg(model: str, tag: str):
211
+ if model not in _PRETRAINED:
212
+ return {}
213
+ model_pretrained = _PRETRAINED[model]
214
+ return model_pretrained.get(_clean_tag(tag), {})
215
+
216
+
217
+ def get_pretrained_url(model: str, tag: str):
218
+ cfg = get_pretrained_cfg(model, _clean_tag(tag))
219
+ return cfg.get("url", "")
220
+
221
+
222
+ def download_pretrained_from_url(
223
+ url: str,
224
+ cache_dir: Union[str, None] = None,
225
+ ):
226
+ if not cache_dir:
227
+ cache_dir = os.path.expanduser("~/.cache/clip")
228
+ os.makedirs(cache_dir, exist_ok=True)
229
+ filename = os.path.basename(url)
230
+
231
+ if "openaipublic" in url:
232
+ expected_sha256 = url.split("/")[-2]
233
+ elif "mlfoundations" in url:
234
+ expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
235
+ else:
236
+ expected_sha256 = ""
237
+
238
+ download_target = os.path.join(cache_dir, filename)
239
+
240
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
241
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
242
+
243
+ if os.path.isfile(download_target):
244
+ if expected_sha256:
245
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
246
+ return download_target
247
+ else:
248
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
249
+ else:
250
+ return download_target
251
+
252
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
253
+ with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit="iB", unit_scale=True) as loop:
254
+ while True:
255
+ buffer = source.read(8192)
256
+ if not buffer:
257
+ break
258
+
259
+ output.write(buffer)
260
+ loop.update(len(buffer))
261
+
262
+ if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
263
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
264
+
265
+ return download_target
266
+
267
+
268
+ def has_hf_hub(necessary=False):
269
+ if not _has_hf_hub and necessary:
270
+ # if no HF Hub module installed, and it is necessary to continue, raise error
271
+ raise RuntimeError("Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.")
272
+ return _has_hf_hub
273
+
274
+
275
+ def download_pretrained_from_hf(
276
+ model_id: str,
277
+ filename: str = "open_clip_pytorch_model.bin",
278
+ revision=None,
279
+ cache_dir: Union[str, None] = None,
280
+ ):
281
+ has_hf_hub(True)
282
+ cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
283
+ return cached_file
284
+
285
+
286
+ def download_pretrained(
287
+ cfg: Dict,
288
+ force_hf_hub: bool = False,
289
+ cache_dir: Union[str, None] = None,
290
+ ):
291
+ target = ""
292
+ if not cfg:
293
+ return target
294
+
295
+ download_url = cfg.get("url", "")
296
+ download_hf_hub = cfg.get("hf_hub", "")
297
+ if download_hf_hub and force_hf_hub:
298
+ # use HF hub even if url exists
299
+ download_url = ""
300
+
301
+ if download_url:
302
+ target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
303
+ elif download_hf_hub:
304
+ has_hf_hub(True)
305
+ # we assume the hf_hub entries in pretrained config combine model_id + filename in
306
+ # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
307
+ # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
308
+ model_id, filename = os.path.split(download_hf_hub)
309
+ if filename:
310
+ target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
311
+ else:
312
+ target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
313
+
314
+ return target
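A minimal usage sketch of the helpers above (not part of the uploaded file). The model/tag strings are illustrative placeholders and must match entries registered in _PRETRAINED; the import path follows this file's location in the upload.

    from blip3o.model.multimodal_encoder.dev_eva_clip.eva_clip.pretrained import (
        get_pretrained_cfg, download_pretrained,
    )

    cfg = get_pretrained_cfg("EVA02-CLIP-L-14", "eva_clip")  # {} if the model/tag pair is not registered
    if cfg:
        ckpt_path = download_pretrained(cfg)                 # URL downloads default to ~/.cache/clip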
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py ADDED
@@ -0,0 +1,131 @@
1
+ from math import pi
2
+ import torch
3
+ from torch import nn
4
+ from einops import rearrange, repeat
5
+ import logging
6
+
7
+
8
+ def broadcat(tensors, dim=-1):
9
+ num_tensors = len(tensors)
10
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
11
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
12
+ shape_len = list(shape_lens)[0]
13
+ dim = (dim + shape_len) if dim < 0 else dim
14
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
15
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
16
+     assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), "invalid dimensions for broadcastable concatenation"
17
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
18
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
19
+ expanded_dims.insert(dim, (dim, dims[dim]))
20
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
21
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
22
+ return torch.cat(tensors, dim=dim)
23
+
24
+
25
+ def rotate_half(x):
26
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
27
+ x1, x2 = x.unbind(dim=-1)
28
+ x = torch.stack((-x2, x1), dim=-1)
29
+ return rearrange(x, "... d r -> ... (d r)")
30
+
31
+
32
+ class VisionRotaryEmbedding(nn.Module):
33
+ def __init__(
34
+ self,
35
+ dim,
36
+ pt_seq_len,
37
+ ft_seq_len=None,
38
+ custom_freqs=None,
39
+ freqs_for="lang",
40
+ theta=10000,
41
+ max_freq=10,
42
+ num_freqs=1,
43
+ ):
44
+ super().__init__()
45
+ if custom_freqs:
46
+ freqs = custom_freqs
47
+ elif freqs_for == "lang":
48
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
49
+ elif freqs_for == "pixel":
50
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
51
+ elif freqs_for == "constant":
52
+ freqs = torch.ones(num_freqs).float()
53
+ else:
54
+ raise ValueError(f"unknown modality {freqs_for}")
55
+
56
+ if ft_seq_len is None:
57
+ ft_seq_len = pt_seq_len
58
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
59
+
60
+ freqs_h = torch.einsum("..., f -> ... f", t, freqs)
61
+ freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
62
+
63
+ freqs_w = torch.einsum("..., f -> ... f", t, freqs)
64
+ freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
65
+
66
+ freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
67
+
68
+ self.register_buffer("freqs_cos", freqs.cos())
69
+ self.register_buffer("freqs_sin", freqs.sin())
70
+
71
+ logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")
72
+
73
+ def forward(self, t, start_index=0):
74
+ rot_dim = self.freqs_cos.shape[-1]
75
+ end_index = start_index + rot_dim
76
+ assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
77
+ t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
78
+ t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
79
+
80
+ return torch.cat((t_left, t, t_right), dim=-1)
81
+
82
+
83
+ class VisionRotaryEmbeddingFast(nn.Module):
84
+ def __init__(self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, patch_dropout=0.0):
85
+ super().__init__()
86
+ if custom_freqs:
87
+ freqs = custom_freqs
88
+ elif freqs_for == "lang":
89
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
90
+ elif freqs_for == "pixel":
91
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
92
+ elif freqs_for == "constant":
93
+ freqs = torch.ones(num_freqs).float()
94
+ else:
95
+ raise ValueError(f"unknown modality {freqs_for}")
96
+
97
+ if ft_seq_len is None:
98
+ ft_seq_len = pt_seq_len
99
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
100
+
101
+ freqs = torch.einsum("..., f -> ... f", t, freqs)
102
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
103
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
104
+
105
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
106
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
107
+
108
+ self.patch_dropout = patch_dropout
109
+
110
+ self.register_buffer("freqs_cos", freqs_cos)
111
+ self.register_buffer("freqs_sin", freqs_sin)
112
+
113
+ logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")
114
+
115
+ def forward(self, t, patch_indices_keep=None):
116
+ if patch_indices_keep is not None:
117
+ batch = t.size()[0]
118
+ batch_indices = torch.arange(batch)
119
+ batch_indices = batch_indices[..., None]
120
+
121
+ freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
122
+ freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
123
+
124
+ freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
125
+ freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j")
126
+ freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
127
+ freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j")
128
+
129
+ return t * freqs_cos + rotate_half(t) * freqs_sin
130
+
131
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
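A shape-level sketch of how the fast variant above is typically applied (illustrative sizes; in EVA-style ViTs the rotation is applied per attention head to the patch tokens only, with the CLS token excluded, and einops must be installed):

    import torch
    from blip3o.model.multimodal_encoder.dev_eva_clip.eva_clip.rope import VisionRotaryEmbeddingFast

    rope = VisionRotaryEmbeddingFast(dim=32, pt_seq_len=16)  # rotates 64-dim heads over a 16x16 patch grid
    q = torch.randn(2, 12, 16 * 16, 64)                      # (batch, heads, patches, head_dim)
    q_rot = rope(q)                                          # same shape, with 2-D rotary position encoding applied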
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py ADDED
@@ -0,0 +1,114 @@
1
+ """ timm model adapter
2
+
3
+ Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4
+ """
5
+
6
+ import logging
7
+ from collections import OrderedDict
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ try:
13
+ import timm
14
+ from timm.models.layers import Mlp, to_2tuple
15
+
16
+ try:
17
+ # old timm imports < 0.8.1
18
+ from timm.models.layers.attention_pool2d import RotAttentionPool2d
19
+ from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
20
+ except ImportError:
21
+ # new timm imports >= 0.8.1
22
+ from timm.layers import RotAttentionPool2d
23
+ from timm.layers import AttentionPool2d as AbsAttentionPool2d
24
+ except ImportError:
25
+ timm = None
26
+
27
+ from .utils import freeze_batch_norm_2d
28
+
29
+
30
+ class TimmModel(nn.Module):
31
+ """timm model adapter
32
+ # FIXME this adapter is a work in progress, may change in ways that break weight compat
33
+ """
34
+
35
+ def __init__(self, model_name, embed_dim, image_size=224, pool="avg", proj="linear", proj_bias=False, drop=0.0, pretrained=False):
36
+ super().__init__()
37
+ if timm is None:
38
+ raise RuntimeError("Please `pip install timm` to use timm models.")
39
+
40
+ self.image_size = to_2tuple(image_size)
41
+ self.trunk = timm.create_model(model_name, pretrained=pretrained)
42
+ feat_size = self.trunk.default_cfg.get("pool_size", None)
43
+ feature_ndim = 1 if not feat_size else 2
44
+ if pool in ("abs_attn", "rot_attn"):
45
+ assert feature_ndim == 2
46
+ # if attn pooling used, remove both classifier and default pool
47
+ self.trunk.reset_classifier(0, global_pool="")
48
+ else:
49
+ # reset global pool if pool config set, otherwise leave as network default
50
+ reset_kwargs = dict(global_pool=pool) if pool else {}
51
+ self.trunk.reset_classifier(0, **reset_kwargs)
52
+ prev_chs = self.trunk.num_features
53
+
54
+ head_layers = OrderedDict()
55
+ if pool == "abs_attn":
56
+ head_layers["pool"] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
57
+ prev_chs = embed_dim
58
+ elif pool == "rot_attn":
59
+ head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
60
+ prev_chs = embed_dim
61
+ else:
62
+ assert proj, "projection layer needed if non-attention pooling is used."
63
+
64
+ # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
65
+ if proj == "linear":
66
+ head_layers["drop"] = nn.Dropout(drop)
67
+ head_layers["proj"] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
68
+ elif proj == "mlp":
69
+ head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))
70
+
71
+ self.head = nn.Sequential(head_layers)
72
+
73
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
74
+ """lock modules
75
+ Args:
76
+ unlocked_groups (int): leave last n layer groups unlocked (default: 0)
77
+ """
78
+ if not unlocked_groups:
79
+ # lock full model
80
+ for param in self.trunk.parameters():
81
+ param.requires_grad = False
82
+ if freeze_bn_stats:
83
+ freeze_batch_norm_2d(self.trunk)
84
+ else:
85
+ # NOTE: partial freeze requires latest timm (master) branch and is subject to change
86
+ try:
87
+ # FIXME import here until API stable and in an official release
88
+ from timm.models.helpers import group_parameters, group_modules
89
+ except ImportError:
90
+ raise RuntimeError("Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`")
91
+ matcher = self.trunk.group_matcher()
92
+ gparams = group_parameters(self.trunk, matcher)
93
+ max_layer_id = max(gparams.keys())
94
+ max_layer_id = max_layer_id - unlocked_groups
95
+ for group_idx in range(max_layer_id + 1):
96
+ group = gparams[group_idx]
97
+ for param in group:
98
+ self.trunk.get_parameter(param).requires_grad = False
99
+ if freeze_bn_stats:
100
+ gmodules = group_modules(self.trunk, matcher, reverse=True)
101
+ gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
102
+ freeze_batch_norm_2d(self.trunk, gmodules)
103
+
104
+ @torch.jit.ignore
105
+ def set_grad_checkpointing(self, enable=True):
106
+ try:
107
+ self.trunk.set_grad_checkpointing(enable)
108
+ except Exception as e:
109
+ logging.warning("grad checkpointing not supported for this timm image tower, continuing without...")
110
+
111
+ def forward(self, x):
112
+ x = self.trunk(x)
113
+ x = self.head(x)
114
+ return x
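A minimal instantiation sketch for the adapter above (assumes timm is installed; the backbone name and embedding size are placeholders rather than configs shipped with this repo):

    import torch
    from blip3o.model.multimodal_encoder.dev_eva_clip.eva_clip.timm_model import TimmModel

    tower = TimmModel("resnet50", embed_dim=512, image_size=224, pool="avg", proj="linear")
    feats = tower(torch.randn(1, 3, 224, 224))  # -> tensor of shape (1, 512)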
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/tokenizer.py ADDED
@@ -0,0 +1,205 @@
1
+ """ CLIP tokenizer
2
+
3
+ Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+
6
+ import gzip
7
+ import html
8
+ import os
9
+ from functools import lru_cache
10
+ from typing import Union, List
11
+
12
+ import ftfy
13
+ import regex as re
14
+ import torch
15
+
16
+ # https://stackoverflow.com/q/62691279
17
+ import os
18
+
19
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
20
+
21
+
22
+ @lru_cache()
23
+ def default_bpe():
24
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
25
+
26
+
27
+ @lru_cache()
28
+ def bytes_to_unicode():
29
+ """
30
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
31
+ The reversible bpe codes work on unicode strings.
32
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
33
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
34
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
35
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
36
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
37
+ """
38
+ bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
39
+ cs = bs[:]
40
+ n = 0
41
+ for b in range(2**8):
42
+ if b not in bs:
43
+ bs.append(b)
44
+ cs.append(2**8 + n)
45
+ n += 1
46
+ cs = [chr(n) for n in cs]
47
+ return dict(zip(bs, cs))
48
+
49
+
50
+ def get_pairs(word):
51
+ """Return set of symbol pairs in a word.
52
+ Word is represented as tuple of symbols (symbols being variable-length strings).
53
+ """
54
+ pairs = set()
55
+ prev_char = word[0]
56
+ for char in word[1:]:
57
+ pairs.add((prev_char, char))
58
+ prev_char = char
59
+ return pairs
60
+
61
+
62
+ def basic_clean(text):
63
+ text = ftfy.fix_text(text)
64
+ text = html.unescape(html.unescape(text))
65
+ return text.strip()
66
+
67
+
68
+ def whitespace_clean(text):
69
+ text = re.sub(r"\s+", " ", text)
70
+ text = text.strip()
71
+ return text
72
+
73
+
74
+ class SimpleTokenizer(object):
75
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
76
+ self.byte_encoder = bytes_to_unicode()
77
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
78
+ merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
79
+ merges = merges[1 : 49152 - 256 - 2 + 1]
80
+ merges = [tuple(merge.split()) for merge in merges]
81
+ vocab = list(bytes_to_unicode().values())
82
+ vocab = vocab + [v + "</w>" for v in vocab]
83
+ for merge in merges:
84
+ vocab.append("".join(merge))
85
+ if not special_tokens:
86
+ special_tokens = ["<start_of_text>", "<end_of_text>"]
87
+ else:
88
+ special_tokens = ["<start_of_text>", "<end_of_text>"] + special_tokens
89
+ vocab.extend(special_tokens)
90
+ self.encoder = dict(zip(vocab, range(len(vocab))))
91
+ self.decoder = {v: k for k, v in self.encoder.items()}
92
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
93
+ self.cache = {t: t for t in special_tokens}
94
+ special = "|".join(special_tokens)
95
+ self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
96
+
97
+ self.vocab_size = len(self.encoder)
98
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
99
+
100
+ def bpe(self, token):
101
+ if token in self.cache:
102
+ return self.cache[token]
103
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
104
+ pairs = get_pairs(word)
105
+
106
+ if not pairs:
107
+ return token + "</w>"
108
+
109
+ while True:
110
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
111
+ if bigram not in self.bpe_ranks:
112
+ break
113
+ first, second = bigram
114
+ new_word = []
115
+ i = 0
116
+ while i < len(word):
117
+ try:
118
+ j = word.index(first, i)
119
+ new_word.extend(word[i:j])
120
+ i = j
121
+             except ValueError:  # `first` not found in the remainder of word
122
+ new_word.extend(word[i:])
123
+ break
124
+
125
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
126
+ new_word.append(first + second)
127
+ i += 2
128
+ else:
129
+ new_word.append(word[i])
130
+ i += 1
131
+ new_word = tuple(new_word)
132
+ word = new_word
133
+ if len(word) == 1:
134
+ break
135
+ else:
136
+ pairs = get_pairs(word)
137
+ word = " ".join(word)
138
+ self.cache[token] = word
139
+ return word
140
+
141
+ def encode(self, text):
142
+ bpe_tokens = []
143
+ text = whitespace_clean(basic_clean(text)).lower()
144
+ for token in re.findall(self.pat, text):
145
+ token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
146
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
147
+ return bpe_tokens
148
+
149
+ def decode(self, tokens):
150
+ text = "".join([self.decoder[token] for token in tokens])
151
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace").replace("</w>", " ")
152
+ return text
153
+
154
+
155
+ _tokenizer = SimpleTokenizer()
156
+
157
+
158
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
159
+ """
160
+ Returns the tokenized representation of given input string(s)
161
+
162
+ Parameters
163
+ ----------
164
+ texts : Union[str, List[str]]
165
+ An input string or a list of input strings to tokenize
166
+ context_length : int
167
+ The context length to use; all CLIP models use 77 as the context length
168
+
169
+ Returns
170
+ -------
171
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
172
+ """
173
+ if isinstance(texts, str):
174
+ texts = [texts]
175
+
176
+ sot_token = _tokenizer.encoder["<start_of_text>"]
177
+ eot_token = _tokenizer.encoder["<end_of_text>"]
178
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
179
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
180
+
181
+ for i, tokens in enumerate(all_tokens):
182
+ if len(tokens) > context_length:
183
+ tokens = tokens[:context_length] # Truncate
184
+ tokens[-1] = eot_token
185
+ result[i, : len(tokens)] = torch.tensor(tokens)
186
+
187
+ return result
188
+
189
+
190
+ class HFTokenizer:
191
+ "HuggingFace tokenizer wrapper"
192
+
193
+ def __init__(self, tokenizer_name: str):
194
+ from transformers import AutoTokenizer
195
+
196
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
197
+
198
+ def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
199
+ # same cleaning as for default tokenizer, except lowercasing
200
+ # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
201
+ if isinstance(texts, str):
202
+ texts = [texts]
203
+ texts = [whitespace_clean(basic_clean(text)) for text in texts]
204
+ input_ids = self.tokenizer(texts, return_tensors="pt", max_length=context_length, padding="max_length", truncation=True).input_ids
205
+ return input_ids
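A quick usage sketch of the bundled tokenizer (illustrative prompts; ftfy and regex must be installed since they are imported at module load):

    from blip3o.model.multimodal_encoder.dev_eva_clip.eva_clip.tokenizer import tokenize

    tokens = tokenize(["a photo of a cat", "a diagram"])  # LongTensor of shape (2, 77)
    # each row is <start_of_text> ... <end_of_text> followed by zero padding up to context_length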
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py ADDED
@@ -0,0 +1,104 @@
1
+ from typing import Optional, Sequence, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.transforms.functional as F
6
+
7
+ from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, CenterCrop
8
+
9
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
10
+
11
+
12
+ class ResizeMaxSize(nn.Module):
13
+
14
+ def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0):
15
+ super().__init__()
16
+ if not isinstance(max_size, int):
17
+ raise TypeError(f"Size should be int. Got {type(max_size)}")
18
+ self.max_size = max_size
19
+ self.interpolation = interpolation
20
+         self.fn = max if fn == "max" else min
21
+ self.fill = fill
22
+
23
+ def forward(self, img):
24
+ if isinstance(img, torch.Tensor):
25
+ height, width = img.shape[:2]
26
+ else:
27
+ width, height = img.size
28
+ scale = self.max_size / float(max(height, width))
29
+ if scale != 1.0:
30
+ new_size = tuple(round(dim * scale) for dim in (height, width))
31
+ img = F.resize(img, new_size, self.interpolation)
32
+ pad_h = self.max_size - new_size[0]
33
+ pad_w = self.max_size - new_size[1]
34
+ img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill)
35
+ return img
36
+
37
+
38
+ def _convert_to_rgb(image):
39
+ return image.convert("RGB")
40
+
41
+
42
+ # class CatGen(nn.Module):
43
+ # def __init__(self, num=4):
44
+ # self.num = num
45
+ # def mixgen_batch(image, text):
46
+ # batch_size = image.shape[0]
47
+ # index = np.random.permutation(batch_size)
48
+
49
+ # cat_images = []
50
+ # for i in range(batch_size):
51
+ # # image mixup
52
+ # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
53
+ # # text concat
54
+ # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
55
+ # text = torch.stack(text)
56
+ # return image, text
57
+
58
+
59
+ def image_transform(
60
+ image_size: int,
61
+ is_train: bool,
62
+ mean: Optional[Tuple[float, ...]] = None,
63
+ std: Optional[Tuple[float, ...]] = None,
64
+ resize_longest_max: bool = False,
65
+ fill_color: int = 0,
66
+ ):
67
+ mean = mean or OPENAI_DATASET_MEAN
68
+ if not isinstance(mean, (list, tuple)):
69
+ mean = (mean,) * 3
70
+
71
+ std = std or OPENAI_DATASET_STD
72
+ if not isinstance(std, (list, tuple)):
73
+ std = (std,) * 3
74
+
75
+ if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
76
+ # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
77
+ image_size = image_size[0]
78
+
79
+ normalize = Normalize(mean=mean, std=std)
80
+ if is_train:
81
+ return Compose(
82
+ [
83
+ RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
84
+ _convert_to_rgb,
85
+ ToTensor(),
86
+ normalize,
87
+ ]
88
+ )
89
+ else:
90
+ if resize_longest_max:
91
+ transforms = [ResizeMaxSize(image_size, fill=fill_color)]
92
+ else:
93
+ transforms = [
94
+ Resize(image_size, interpolation=InterpolationMode.BICUBIC),
95
+ CenterCrop(image_size),
96
+ ]
97
+ transforms.extend(
98
+ [
99
+ _convert_to_rgb,
100
+ ToTensor(),
101
+ normalize,
102
+ ]
103
+ )
104
+ return Compose(transforms)
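A sketch of building the eval-time preprocessing pipeline from the factory above (image size is illustrative; mean/std default to the OPENAI constants):

    from PIL import Image
    from blip3o.model.multimodal_encoder.dev_eva_clip.eva_clip.transform import image_transform

    preprocess = image_transform(image_size=224, is_train=False)   # Resize + CenterCrop + ToTensor + Normalize
    # pixel_values = preprocess(Image.open("example.jpg"))         # hypothetical file -> FloatTensor (3, 224, 224)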
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/transformer.py ADDED
@@ -0,0 +1,683 @@
1
+ import os
2
+ import logging
3
+ from collections import OrderedDict
4
+ import math
5
+ from typing import Callable, Optional, Sequence
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ try:
12
+ from timm.models.layers import trunc_normal_
13
+ except ImportError:
14
+ from timm.layers import trunc_normal_
15
+
16
+ from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
17
+ from .utils import to_2tuple
18
+
19
+ if os.getenv("ENV_TYPE") == "deepspeed":
20
+ try:
21
+ import deepspeed
22
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
23
+     except ImportError:
24
+ print("Please 'pip install deepspeed'")
25
+ deepspeed = None
26
+ from torch.utils.checkpoint import checkpoint
27
+ else:
28
+ from torch.utils.checkpoint import checkpoint
29
+
30
+ try:
31
+ import xformers.ops as xops
32
+ except ImportError:
33
+ xops = None
34
+ # print("Please 'pip install xformers'")
35
+
36
+
37
+ class LayerNormFp32(nn.LayerNorm):
38
+ """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
39
+
40
+ def __init__(self, *args, **kwargs):
41
+ super().__init__(*args, **kwargs)
42
+
43
+ def forward(self, x: torch.Tensor):
44
+ output = F.layer_norm(
45
+ x.float(),
46
+ self.normalized_shape,
47
+ self.weight.float() if self.weight is not None else None,
48
+ self.bias.float() if self.bias is not None else None,
49
+ self.eps,
50
+ )
51
+ return output.type_as(x)
52
+
53
+
54
+ class LayerNorm(nn.LayerNorm):
55
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
56
+
57
+ def forward(self, x: torch.Tensor):
58
+ orig_type = x.dtype
59
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
60
+ return x.to(orig_type)
61
+
62
+
63
+ class QuickGELU(nn.Module):
64
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
65
+ def forward(self, x: torch.Tensor):
66
+ return x * torch.sigmoid(1.702 * x)
67
+
68
+
69
+ class LayerScale(nn.Module):
70
+ def __init__(self, dim, init_values=1e-5, inplace=False):
71
+ super().__init__()
72
+ self.inplace = inplace
73
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
74
+
75
+ def forward(self, x):
76
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
77
+
78
+
79
+ class PatchDropout(nn.Module):
80
+ """
81
+ https://arxiv.org/abs/2212.00794
82
+ """
83
+
84
+ def __init__(self, prob, exclude_first_token=True):
85
+ super().__init__()
86
+ assert 0 <= prob < 1.0
87
+ self.prob = prob
88
+ self.exclude_first_token = exclude_first_token # exclude CLS token
89
+ logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
90
+
91
+ def forward(self, x):
92
+ if not self.training or self.prob == 0.0:
93
+ return x
94
+
95
+ if self.exclude_first_token:
96
+ cls_tokens, x = x[:, :1], x[:, 1:]
97
+ else:
98
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
99
+
100
+ batch = x.size()[0]
101
+ num_tokens = x.size()[1]
102
+
103
+ batch_indices = torch.arange(batch)
104
+ batch_indices = batch_indices[..., None]
105
+
106
+ keep_prob = 1 - self.prob
107
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
108
+
109
+ rand = torch.randn(batch, num_tokens)
110
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
111
+
112
+ x = x[batch_indices, patch_indices_keep]
113
+
114
+ if self.exclude_first_token:
115
+ x = torch.cat((cls_tokens, x), dim=1)
116
+
117
+ if self.training and os.getenv("RoPE") == "1":
118
+ return x, patch_indices_keep
119
+
120
+ return x
121
+
122
+
123
+ def _in_projection_packed(
124
+ q: torch.Tensor,
125
+ k: torch.Tensor,
126
+ v: torch.Tensor,
127
+ w: torch.Tensor,
128
+ b: Optional[torch.Tensor] = None,
129
+ ):
130
+ """
131
+ https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726
132
+ """
133
+ E = q.size(-1)
134
+ if k is v:
135
+ if q is k:
136
+ # self-attention
137
+ return F.linear(q, w, b).chunk(3, dim=-1)
138
+ else:
139
+ # encoder-decoder attention
140
+ w_q, w_kv = w.split([E, E * 2])
141
+ if b is None:
142
+ b_q = b_kv = None
143
+ else:
144
+ b_q, b_kv = b.split([E, E * 2])
145
+ return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
146
+ else:
147
+ w_q, w_k, w_v = w.chunk(3)
148
+ if b is None:
149
+ b_q = b_k = b_v = None
150
+ else:
151
+ b_q, b_k, b_v = b.chunk(3)
152
+ return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
153
+
154
+
155
+ class Attention(nn.Module):
156
+ def __init__(self, dim, num_heads=8, qkv_bias=True, scaled_cosine=False, scale_heads=False, logit_scale_max=math.log(1.0 / 0.01), attn_drop=0.0, proj_drop=0.0, xattn=False, rope=False):
157
+ super().__init__()
158
+ self.scaled_cosine = scaled_cosine
159
+ self.scale_heads = scale_heads
160
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
161
+ self.num_heads = num_heads
162
+ self.head_dim = dim // num_heads
163
+ self.scale = self.head_dim**-0.5
164
+ self.logit_scale_max = logit_scale_max
165
+
166
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
167
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
168
+ if qkv_bias:
169
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
170
+ else:
171
+ self.in_proj_bias = None
172
+
173
+ if self.scaled_cosine:
174
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
175
+ else:
176
+ self.logit_scale = None
177
+ self.attn_drop = nn.Dropout(attn_drop)
178
+ if self.scale_heads:
179
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
180
+ else:
181
+ self.head_scale = None
182
+ self.out_proj = nn.Linear(dim, dim)
183
+ self.out_drop = nn.Dropout(proj_drop)
184
+ self.xattn = xattn
185
+ self.xattn_drop = attn_drop
186
+ self.rope = rope
187
+
188
+ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
189
+ L, N, C = x.shape
190
+ q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
191
+ if self.xattn:
192
+ q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
193
+ k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
194
+ v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
195
+
196
+ x = xops.memory_efficient_attention(
197
+ q,
198
+ k,
199
+ v,
200
+ p=self.xattn_drop,
201
+ scale=self.scale if self.logit_scale is None else None,
202
+ attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,
203
+ )
204
+ else:
205
+ q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
206
+ k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
207
+ v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
208
+
209
+ if self.logit_scale is not None:
210
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
211
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
212
+ attn = attn.view(N, self.num_heads, L, L) * logit_scale
213
+ attn = attn.view(-1, L, L)
214
+ else:
215
+ q = q * self.scale
216
+ attn = torch.bmm(q, k.transpose(-1, -2))
217
+
218
+ if attn_mask is not None:
219
+ if attn_mask.dtype == torch.bool:
220
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
221
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
222
+ attn_mask = new_attn_mask
223
+ attn += attn_mask
224
+
225
+ attn = attn.softmax(dim=-1)
226
+ attn = self.attn_drop(attn)
227
+
228
+ x = torch.bmm(attn, v)
229
+
230
+ if self.head_scale is not None:
231
+ x = x.view(N, self.num_heads, L, C) * self.head_scale
232
+ x = x.view(-1, L, C)
233
+ x = x.transpose(0, 1).reshape(L, N, C)
234
+ x = self.out_proj(x)
235
+ x = self.out_drop(x)
236
+ return x
237
+
238
+
239
+ class CustomAttention(nn.Module):
240
+ def __init__(self, dim, num_heads=8, qkv_bias=True, scaled_cosine=True, scale_heads=False, logit_scale_max=math.log(1.0 / 0.01), attn_drop=0.0, proj_drop=0.0, xattn=False):
241
+ super().__init__()
242
+ self.scaled_cosine = scaled_cosine
243
+ self.scale_heads = scale_heads
244
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
245
+ self.num_heads = num_heads
246
+ self.head_dim = dim // num_heads
247
+ self.scale = self.head_dim**-0.5
248
+ self.logit_scale_max = logit_scale_max
249
+
250
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
251
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
252
+ if qkv_bias:
253
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
254
+ else:
255
+ self.in_proj_bias = None
256
+
257
+ if self.scaled_cosine:
258
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
259
+ else:
260
+ self.logit_scale = None
261
+ self.attn_drop = nn.Dropout(attn_drop)
262
+ if self.scale_heads:
263
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
264
+ else:
265
+ self.head_scale = None
266
+ self.out_proj = nn.Linear(dim, dim)
267
+ self.out_drop = nn.Dropout(proj_drop)
268
+ self.xattn = xattn
269
+ self.xattn_drop = attn_drop
270
+
271
+ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
272
+ q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
273
+ N_q, B_q, C_q = q.shape
274
+ N_k, B_k, C_k = k.shape
275
+ N_v, B_v, C_v = v.shape
276
+ if self.xattn:
277
+ # B, N, C -> B, N, num_heads, C
278
+ q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1)
279
+ k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1)
280
+ v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1)
281
+
282
+ x = xops.memory_efficient_attention(q, k, v, p=self.xattn_drop, scale=self.scale if self.logit_scale is None else None, attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None)
283
+ else:
284
+ # B*H, L, C
285
+ q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1)
286
+ k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1)
287
+ v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1)
288
+
289
+ if self.logit_scale is not None:
290
+ # B*H, N_q, N_k
291
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
292
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
293
+ attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale
294
+ attn = attn.view(-1, N_q, N_k)
295
+ else:
296
+ q = q * self.scale
297
+ attn = torch.bmm(q, k.transpose(-1, -2))
298
+
299
+ if attn_mask is not None:
300
+ if attn_mask.dtype == torch.bool:
301
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
302
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
303
+ attn_mask = new_attn_mask
304
+ attn += attn_mask
305
+
306
+ attn = attn.softmax(dim=-1)
307
+ attn = self.attn_drop(attn)
308
+
309
+ x = torch.bmm(attn, v)
310
+
311
+ if self.head_scale is not None:
312
+ x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale
313
+ x = x.view(-1, N_q, C_q)
314
+ x = x.transpose(0, 1).reshape(N_q, B_q, C_q)
315
+ x = self.out_proj(x)
316
+ x = self.out_drop(x)
317
+ return x
318
+
319
+
320
+ class CustomResidualAttentionBlock(nn.Module):
321
+ def __init__(
322
+ self,
323
+ d_model: int,
324
+ n_head: int,
325
+ mlp_ratio: float = 4.0,
326
+ ls_init_value: float = None,
327
+ act_layer: Callable = nn.GELU,
328
+ norm_layer: Callable = LayerNorm,
329
+ scale_cosine_attn: bool = False,
330
+ scale_heads: bool = False,
331
+ scale_attn: bool = False,
332
+ scale_fc: bool = False,
333
+ cross_attn: bool = False,
334
+ xattn: bool = False,
335
+ ):
336
+ super().__init__()
337
+
338
+ self.ln_1 = norm_layer(d_model)
339
+ self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1
340
+ self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1
341
+ self.attn = CustomAttention(d_model, n_head, qkv_bias=True, attn_drop=0.0, proj_drop=0.0, scaled_cosine=scale_cosine_attn, scale_heads=scale_heads, xattn=xattn)
342
+
343
+ self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
344
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
345
+
346
+ self.ln_2 = norm_layer(d_model)
347
+ mlp_width = int(d_model * mlp_ratio)
348
+ self.mlp = nn.Sequential(OrderedDict([("c_fc", nn.Linear(d_model, mlp_width)), ("ln", norm_layer(mlp_width) if scale_fc else nn.Identity()), ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, d_model))]))
349
+
350
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
351
+
352
+ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
353
+ q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask)))
354
+ q = q + self.ls_2(self.mlp(self.ln_2(q)))
355
+ return q
356
+
357
+
358
+ class CustomTransformer(nn.Module):
359
+ def __init__(
360
+ self,
361
+ width: int,
362
+ layers: int,
363
+ heads: int,
364
+ mlp_ratio: float = 4.0,
365
+ ls_init_value: float = None,
366
+ act_layer: Callable = nn.GELU,
367
+ norm_layer: Callable = LayerNorm,
368
+ scale_cosine_attn: bool = True,
369
+ scale_heads: bool = False,
370
+ scale_attn: bool = False,
371
+ scale_fc: bool = False,
372
+ cross_attn: bool = False,
373
+ xattn: bool = False,
374
+ ):
375
+ super().__init__()
376
+ self.width = width
377
+ self.layers = layers
378
+ self.grad_checkpointing = False
379
+ self.xattn = xattn
380
+
381
+ self.resblocks = nn.ModuleList(
382
+ [
383
+ CustomResidualAttentionBlock(
384
+ width,
385
+ heads,
386
+ mlp_ratio,
387
+ ls_init_value=ls_init_value,
388
+ act_layer=act_layer,
389
+ norm_layer=norm_layer,
390
+ scale_cosine_attn=scale_cosine_attn,
391
+ scale_heads=scale_heads,
392
+ scale_attn=scale_attn,
393
+ scale_fc=scale_fc,
394
+ cross_attn=cross_attn,
395
+ xattn=xattn,
396
+ )
397
+ for _ in range(layers)
398
+ ]
399
+ )
400
+
401
+ def get_cast_dtype(self) -> torch.dtype:
402
+ return self.resblocks[0].mlp.c_fc.weight.dtype
403
+
404
+ def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None):
405
+ if k is None and v is None:
406
+ k = v = q
407
+ for r in self.resblocks:
408
+ if self.grad_checkpointing and not torch.jit.is_scripting():
409
+ q = checkpoint(r, q, k, v, attn_mask)
410
+ else:
411
+ q = r(q, k, v, attn_mask=attn_mask)
412
+ return q
413
+
414
+
415
+ class ResidualAttentionBlock(nn.Module):
416
+ def __init__(
417
+ self,
418
+ d_model: int,
419
+ n_head: int,
420
+ mlp_ratio: float = 4.0,
421
+ ls_init_value: float = None,
422
+ act_layer: Callable = nn.GELU,
423
+ norm_layer: Callable = LayerNorm,
424
+ xattn: bool = False,
425
+ ):
426
+ super().__init__()
427
+
428
+ self.ln_1 = norm_layer(d_model)
429
+ if xattn:
430
+ self.attn = Attention(d_model, n_head, xattn=True)
431
+ else:
432
+ self.attn = nn.MultiheadAttention(d_model, n_head)
433
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
434
+
435
+ self.ln_2 = norm_layer(d_model)
436
+ mlp_width = int(d_model * mlp_ratio)
437
+ self.mlp = nn.Sequential(OrderedDict([("c_fc", nn.Linear(d_model, mlp_width)), ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, d_model))]))
438
+
439
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
440
+ self.xattn = xattn
441
+
442
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
443
+ attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
444
+ if self.xattn:
445
+ return self.attn(x, attn_mask=attn_mask)
446
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
447
+
448
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
449
+ x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))
450
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
451
+ return x
452
+
453
+
454
+ class Transformer(nn.Module):
455
+ def __init__(
456
+ self,
457
+ width: int,
458
+ layers: int,
459
+ heads: int,
460
+ mlp_ratio: float = 4.0,
461
+ ls_init_value: float = None,
462
+ act_layer: Callable = nn.GELU,
463
+ norm_layer: Callable = LayerNorm,
464
+ xattn: bool = False,
465
+ ):
466
+ super().__init__()
467
+ self.width = width
468
+ self.layers = layers
469
+ self.grad_checkpointing = False
470
+
471
+ self.resblocks = nn.ModuleList([ResidualAttentionBlock(width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn) for _ in range(layers)])
472
+
473
+ def get_cast_dtype(self) -> torch.dtype:
474
+ return self.resblocks[0].mlp.c_fc.weight.dtype
475
+
476
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
477
+ for r in self.resblocks:
478
+ if self.grad_checkpointing and not torch.jit.is_scripting():
479
+ x = checkpoint(r, x, attn_mask)
480
+ else:
481
+ x = r(x, attn_mask=attn_mask)
482
+ return x
483
+
484
+
485
+ class VisionTransformer(nn.Module):
486
+ def __init__(
487
+ self,
488
+ image_size: int,
489
+ patch_size: int,
490
+ width: int,
491
+ layers: int,
492
+ heads: int,
493
+ mlp_ratio: float,
494
+ ls_init_value: float = None,
495
+ patch_dropout: float = 0.0,
496
+ global_average_pool: bool = False,
497
+ output_dim: int = 512,
498
+ act_layer: Callable = nn.GELU,
499
+ norm_layer: Callable = LayerNorm,
500
+ xattn: bool = False,
501
+ ):
502
+ super().__init__()
503
+ self.image_size = to_2tuple(image_size)
504
+ self.patch_size = to_2tuple(patch_size)
505
+ self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
506
+ self.output_dim = output_dim
507
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
508
+
509
+ scale = width**-0.5
510
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
511
+ self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
512
+
513
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
514
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
515
+ self.ln_pre = norm_layer(width)
516
+
517
+ self.transformer = Transformer(width, layers, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn)
518
+
519
+ self.global_average_pool = global_average_pool
520
+ self.ln_post = norm_layer(width)
521
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
522
+
523
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
524
+ for param in self.parameters():
525
+ param.requires_grad = False
526
+
527
+ if unlocked_groups != 0:
528
+ groups = [
529
+ [
530
+ self.conv1,
531
+ self.class_embedding,
532
+ self.positional_embedding,
533
+ self.ln_pre,
534
+ ],
535
+ *self.transformer.resblocks[:-1],
536
+ [
537
+ self.transformer.resblocks[-1],
538
+ self.ln_post,
539
+ ],
540
+ self.proj,
541
+ ]
542
+
543
+ def _unlock(x):
544
+ if isinstance(x, Sequence):
545
+ for g in x:
546
+ _unlock(g)
547
+ else:
548
+ if isinstance(x, torch.nn.Parameter):
549
+ x.requires_grad = True
550
+ else:
551
+ for p in x.parameters():
552
+ p.requires_grad = True
553
+
554
+ _unlock(groups[-unlocked_groups:])
555
+
556
+ def get_num_layers(self):
557
+ return self.transformer.layers
558
+
559
+ @torch.jit.ignore
560
+ def set_grad_checkpointing(self, enable=True):
561
+ self.transformer.grad_checkpointing = enable
562
+
563
+ @torch.jit.ignore
564
+ def no_weight_decay(self):
565
+ return {"positional_embedding", "class_embedding"}
566
+
567
+ def forward(self, x: torch.Tensor, return_all_features: bool = False):
568
+ x = self.conv1(x) # shape = [*, width, grid, grid]
569
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
570
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
571
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
572
+ x = x + self.positional_embedding.to(x.dtype)
573
+
574
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
575
+ x = self.patch_dropout(x)
576
+ x = self.ln_pre(x)
577
+
578
+ x = x.permute(1, 0, 2) # NLD -> LND
579
+ x = self.transformer(x)
580
+ x = x.permute(1, 0, 2) # LND -> NLD
581
+
582
+ if not return_all_features:
583
+ if self.global_average_pool:
584
+ x = x.mean(dim=1) # x = x[:,1:,:].mean(dim=1)
585
+ else:
586
+ x = x[:, 0]
587
+
588
+ x = self.ln_post(x)
589
+
590
+ if self.proj is not None:
591
+ x = x @ self.proj
592
+
593
+ return x
594
+
595
+
596
+ class TextTransformer(nn.Module):
597
+ def __init__(
598
+ self,
599
+ context_length: int = 77,
600
+ vocab_size: int = 49408,
601
+ width: int = 512,
602
+ heads: int = 8,
603
+ layers: int = 12,
604
+ ls_init_value: float = None,
605
+ output_dim: int = 512,
606
+ act_layer: Callable = nn.GELU,
607
+ norm_layer: Callable = LayerNorm,
608
+ xattn: bool = False,
609
+ attn_mask: bool = True,
610
+ ):
611
+ super().__init__()
612
+ self.context_length = context_length
613
+ self.vocab_size = vocab_size
614
+ self.width = width
615
+ self.output_dim = output_dim
616
+
617
+ self.token_embedding = nn.Embedding(vocab_size, width)
618
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
619
+ self.transformer = Transformer(width=width, layers=layers, heads=heads, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn)
620
+
621
+ self.xattn = xattn
622
+ self.ln_final = norm_layer(width)
623
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
624
+
625
+ if attn_mask:
626
+ self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
627
+ else:
628
+ self.attn_mask = None
629
+
630
+ self.init_parameters()
631
+
632
+ def init_parameters(self):
633
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
634
+ nn.init.normal_(self.positional_embedding, std=0.01)
635
+
636
+ proj_std = (self.transformer.width**-0.5) * ((2 * self.transformer.layers) ** -0.5)
637
+ attn_std = self.transformer.width**-0.5
638
+ fc_std = (2 * self.transformer.width) ** -0.5
639
+ for block in self.transformer.resblocks:
640
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
641
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
642
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
643
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
644
+
645
+ if self.text_projection is not None:
646
+ nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)
647
+
648
+ @torch.jit.ignore
649
+ def set_grad_checkpointing(self, enable=True):
650
+ self.transformer.grad_checkpointing = enable
651
+
652
+ @torch.jit.ignore
653
+ def no_weight_decay(self):
654
+ # return {'positional_embedding', 'token_embedding'}
655
+ return {"positional_embedding"}
656
+
657
+ def get_num_layers(self):
658
+ return self.transformer.layers
659
+
660
+ def build_attention_mask(self):
661
+         # lazily create the causal attention mask for the text tokens
662
+ # pytorch uses additive attention mask; fill with -inf
663
+ mask = torch.empty(self.context_length, self.context_length)
664
+ mask.fill_(float("-inf"))
665
+         mask.triu_(1)  # zero out the diagonal and lower triangle, keeping -inf above the diagonal
666
+ return mask
667
+
668
+ def forward(self, text, return_all_features: bool = False):
669
+ cast_dtype = self.transformer.get_cast_dtype()
670
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
671
+
672
+ x = x + self.positional_embedding.to(cast_dtype)
673
+ x = x.permute(1, 0, 2) # NLD -> LND
674
+ x = self.transformer(x, attn_mask=self.attn_mask)
675
+ # x = self.transformer(x) # no attention mask is applied
676
+ x = x.permute(1, 0, 2) # LND -> NLD
677
+ x = self.ln_final(x)
678
+
679
+ if not return_all_features:
680
+ # x.shape = [batch_size, n_ctx, transformer.width]
681
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
682
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
683
+ return x
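A shape-level sketch of the two towers defined above (requires timm, which this module imports at load time; the hyper-parameters are illustrative and not the EVA configs under model_configs/):

    import torch
    from blip3o.model.multimodal_encoder.dev_eva_clip.eva_clip.transformer import VisionTransformer, TextTransformer

    vit = VisionTransformer(image_size=224, patch_size=16, width=768, layers=12, heads=12, mlp_ratio=4.0, output_dim=512)
    img_feats = vit(torch.randn(1, 3, 224, 224))       # -> (1, 512), projected CLS token

    txt = TextTransformer(vocab_size=49408, width=512, heads=8, layers=12, output_dim=512)
    ids = torch.randint(1, 49408, (1, 77))             # stand-in for tokenize() output
    txt_feats = txt(ids)                               # -> (1, 512); features taken at the arg-max token id (the EOT token for real inputs)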
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/utils.py ADDED
@@ -0,0 +1,321 @@
1
+ from itertools import repeat
2
+ import collections.abc
3
+ import logging
4
+ import math
5
+ import numpy as np
+ from scipy import interpolate  # needed below: resize_rel_pos_embed resamples relative position bias tables with interp2d
6
+
7
+ import torch
8
+ from torch import nn as nn
9
+ from torchvision.ops.misc import FrozenBatchNorm2d
10
+ import torch.nn.functional as F
11
+
12
+
13
+ # open CLIP
14
+ def resize_clip_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
15
+ # Rescale the grid of position embeddings when loading from state_dict
16
+ old_pos_embed = state_dict.get("visual.positional_embedding", None)
17
+ if old_pos_embed is None or not hasattr(model.visual, "grid_size"):
18
+ return
19
+ grid_size = to_2tuple(model.visual.grid_size)
20
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
21
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
22
+ if new_seq_len == old_pos_embed.shape[0]:
23
+ return
24
+
25
+ if extra_tokens:
26
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
27
+ else:
28
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
29
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
30
+
31
+ logging.info("Resizing position embedding grid-size from %s to %s", old_grid_size, grid_size)
32
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
33
+ pos_emb_img = F.interpolate(
34
+ pos_emb_img,
35
+ size=grid_size,
36
+ mode=interpolation,
37
+ align_corners=True,
38
+ )
39
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
40
+ if pos_emb_tok is not None:
41
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
42
+ else:
43
+ new_pos_embed = pos_emb_img
44
+ state_dict["visual.positional_embedding"] = new_pos_embed
45
+
46
+
47
+ def resize_visual_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
48
+ # Rescale the grid of position embeddings when loading from state_dict
49
+ old_pos_embed = state_dict.get("positional_embedding", None)
50
+ if old_pos_embed is None or not hasattr(model.visual, "grid_size"):
51
+ return
52
+ grid_size = to_2tuple(model.visual.grid_size)
53
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
54
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
55
+ if new_seq_len == old_pos_embed.shape[0]:
56
+ return
57
+
58
+ if extra_tokens:
59
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
60
+ else:
61
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
62
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
63
+
64
+ logging.info("Resizing position embedding grid-size from %s to %s", old_grid_size, grid_size)
65
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
66
+ pos_emb_img = F.interpolate(
67
+ pos_emb_img,
68
+ size=grid_size,
69
+ mode=interpolation,
70
+ align_corners=True,
71
+ )
72
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
73
+ if pos_emb_tok is not None:
74
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
75
+ else:
76
+ new_pos_embed = pos_emb_img
77
+ state_dict["positional_embedding"] = new_pos_embed
78
+
79
+
80
+ def resize_evaclip_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
81
+ all_keys = list(state_dict.keys())
82
+ # interpolate position embedding
83
+ if "visual.pos_embed" in state_dict:
84
+ pos_embed_checkpoint = state_dict["visual.pos_embed"]
85
+ embedding_size = pos_embed_checkpoint.shape[-1]
86
+ num_patches = model.visual.patch_embed.num_patches
87
+ # num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
88
+ num_extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
89
+ # height (== width) for the checkpoint position embedding
90
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
91
+ # height (== width) for the new position embedding
92
+ new_size = int(num_patches**0.5)
93
+ # class_token and dist_token are kept unchanged
94
+ if orig_size != new_size:
95
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
96
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
97
+ # only the position tokens are interpolated
98
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
99
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
100
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False)
101
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
102
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
103
+ state_dict["visual.pos_embed"] = new_pos_embed
104
+
105
+ patch_embed_proj = state_dict["visual.patch_embed.proj.weight"]
106
+ patch_size = model.visual.patch_embed.patch_size
107
+ state_dict["visual.patch_embed.proj.weight"] = torch.nn.functional.interpolate(patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False)
108
+
109
+
110
+ def resize_eva_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
111
+ all_keys = list(state_dict.keys())
112
+ # interpolate position embedding
113
+ if "pos_embed" in state_dict:
114
+ pos_embed_checkpoint = state_dict["pos_embed"]
115
+ embedding_size = pos_embed_checkpoint.shape[-1]
116
+ num_patches = model.visual.patch_embed.num_patches
117
+ # num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
118
+ num_extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
119
+ # height (== width) for the checkpoint position embedding
120
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
121
+ # height (== width) for the new position embedding
122
+ new_size = int(num_patches**0.5)
123
+ # class_token and dist_token are kept unchanged
124
+ if orig_size != new_size:
125
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
126
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
127
+ # only the position tokens are interpolated
128
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
129
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
130
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False)
131
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
132
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
133
+ state_dict["pos_embed"] = new_pos_embed
134
+
135
+ patch_embed_proj = state_dict["patch_embed.proj.weight"]
136
+ patch_size = model.visual.patch_embed.patch_size
137
+ state_dict["patch_embed.proj.weight"] = torch.nn.functional.interpolate(patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False)
138
+
139
+
140
+ def resize_rel_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
141
+ all_keys = list(state_dict.keys())
142
+ for key in all_keys:
143
+ if "relative_position_index" in key:
144
+ state_dict.pop(key)
145
+
146
+ if "relative_position_bias_table" in key:
147
+ rel_pos_bias = state_dict[key]
148
+ src_num_pos, num_attn_heads = rel_pos_bias.size()
149
+ dst_num_pos, _ = model.visual.state_dict()[key].size()
150
+ dst_patch_shape = model.visual.patch_embed.patch_shape
151
+ if dst_patch_shape[0] != dst_patch_shape[1]:
152
+ raise NotImplementedError()
153
+ num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
154
+ src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
155
+ dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
156
+ if src_size != dst_size:
157
+ print("Position interpolate for %s from %dx%d to %dx%d" % (key, src_size, src_size, dst_size, dst_size))
158
+ extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
159
+ rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
160
+
161
+ def geometric_progression(a, r, n):
162
+ return a * (1.0 - r**n) / (1.0 - r)
163
+
164
+ left, right = 1.01, 1.5
165
+ while right - left > 1e-6:
166
+ q = (left + right) / 2.0
167
+ gp = geometric_progression(1, q, src_size // 2)
168
+ if gp > dst_size // 2:
169
+ right = q
170
+ else:
171
+ left = q
172
+
173
+ # if q > 1.090307:
174
+ # q = 1.090307
175
+
176
+ dis = []
177
+ cur = 1
178
+ for i in range(src_size // 2):
179
+ dis.append(cur)
180
+ cur += q ** (i + 1)
181
+
182
+ r_ids = [-_ for _ in reversed(dis)]
183
+
184
+ x = r_ids + [0] + dis
185
+ y = r_ids + [0] + dis
186
+
187
+ t = dst_size // 2.0
188
+ dx = np.arange(-t, t + 0.1, 1.0)
189
+ dy = np.arange(-t, t + 0.1, 1.0)
190
+
191
+ print("Original positions = %s" % str(x))
192
+ print("Target positions = %s" % str(dx))
193
+
194
+ all_rel_pos_bias = []
195
+ from scipy import interpolate  # assumes SciPy is available; torch.nn.functional has no interp2d
196
+ for i in range(num_attn_heads):
197
+ z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
198
+ f = interpolate.interp2d(x, y, z, kind="cubic")
199
+ all_rel_pos_bias.append(torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
200
+
201
+ rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
202
+
203
+ new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
204
+ state_dict[key] = new_rel_pos_bias
205
+
206
+ # interpolate position embedding
207
+ if "pos_embed" in state_dict:
208
+ pos_embed_checkpoint = state_dict["pos_embed"]
209
+ embedding_size = pos_embed_checkpoint.shape[-1]
210
+ num_patches = model.visual.patch_embed.num_patches
211
+ num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
212
+ # height (== width) for the checkpoint position embedding
213
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
214
+ # height (== width) for the new position embedding
215
+ new_size = int(num_patches**0.5)
216
+ # class_token and dist_token are kept unchanged
217
+ if orig_size != new_size:
218
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
219
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
220
+ # only the position tokens are interpolated
221
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
222
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
223
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False)
224
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
225
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
226
+ state_dict["pos_embed"] = new_pos_embed
227
+
228
+ patch_embed_proj = state_dict["patch_embed.proj.weight"]
229
+ patch_size = model.visual.patch_embed.patch_size
230
+ state_dict["patch_embed.proj.weight"] = torch.nn.functional.interpolate(patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False)
231
+
232
+
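The relative-position resize above places the source coordinates on a geometric progression rather than a uniform grid and finds the ratio q by bisection. A small standalone sketch of that search with toy sizes (illustrative only; the numbers are not taken from any checkpoint):

    def geometric_progression(a, r, n):
        # sum of a geometric series: a + a*r + ... + a*r**(n-1)
        return a * (1.0 - r ** n) / (1.0 - r)

    src_size, dst_size = 14, 24  # e.g. growing a 14x14 bias table to 24x24
    left, right = 1.01, 1.5
    while right - left > 1e-6:
        q = (left + right) / 2.0
        if geometric_progression(1, q, src_size // 2) > dst_size // 2:
            right = q
        else:
            left = q

    dis, cur = [], 1
    for i in range(src_size // 2):
        dis.append(cur)
        cur += q ** (i + 1)
    coords = [-d for d in reversed(dis)] + [0] + dis  # non-uniform source grid, densest near the center
    print(round(q, 4), coords[:3])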
233
+ def freeze_batch_norm_2d(module, module_match={}, name=""):
234
+ """
235
+ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
236
+ itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
237
+ returned. Otherwise, the module is walked recursively and submodules are converted in place.
238
+
239
+ Args:
240
+ module (torch.nn.Module): Any PyTorch module.
241
+ module_match (dict): Dictionary of full module names to freeze (all if empty)
242
+ name (str): Full module name (prefix)
243
+
244
+ Returns:
245
+ torch.nn.Module: Resulting module
246
+
247
+ Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
248
+ """
249
+ res = module
250
+ is_match = True
251
+ if module_match:
252
+ is_match = name in module_match
253
+ if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
254
+ res = FrozenBatchNorm2d(module.num_features)
255
+ res.num_features = module.num_features
256
+ res.affine = module.affine
257
+ if module.affine:
258
+ res.weight.data = module.weight.data.clone().detach()
259
+ res.bias.data = module.bias.data.clone().detach()
260
+ res.running_mean.data = module.running_mean.data
261
+ res.running_var.data = module.running_var.data
262
+ res.eps = module.eps
263
+ else:
264
+ for child_name, child in module.named_children():
265
+ full_child_name = ".".join([name, child_name]) if name else child_name
266
+ new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
267
+ if new_child is not child:
268
+ res.add_module(child_name, new_child)
269
+ return res
270
+
271
+
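A quick usage sketch for the helper above. It assumes FrozenBatchNorm2d comes from torchvision (as the function appears to expect); the toy model and names are illustrative. The by-hand conversion of one layer mirrors what freeze_batch_norm_2d does for every matched BatchNorm2d:

    import torch.nn as nn
    from torchvision.ops.misc import FrozenBatchNorm2d  # assumption: this is the FrozenBatchNorm2d used above

    toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

    bn = toy[1]
    frozen = FrozenBatchNorm2d(bn.num_features)
    frozen.weight.data = bn.weight.data.clone().detach()
    frozen.bias.data = bn.bias.data.clone().detach()
    frozen.running_mean.data = bn.running_mean.data
    frozen.running_var.data = bn.running_var.data
    toy[1] = frozen  # statistics are now frozen; no training-time updates
    print(toy)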
272
+ # From PyTorch internals
273
+ def _ntuple(n):
274
+ def parse(x):
275
+ if isinstance(x, collections.abc.Iterable):
276
+ return x
277
+ return tuple(repeat(x, n))
278
+
279
+ return parse
280
+
281
+
282
+ to_1tuple = _ntuple(1)
283
+ to_2tuple = _ntuple(2)
284
+ to_3tuple = _ntuple(3)
285
+ to_4tuple = _ntuple(4)
286
+ to_ntuple = lambda n, x: _ntuple(n)(x)
287
+
288
+
289
+ def is_logging(args):
290
+ def is_global_master(args):
291
+ return args.rank == 0
292
+
293
+ def is_local_master(args):
294
+ return args.local_rank == 0
295
+
296
+ def is_master(args, local=False):
297
+ return is_local_master(args) if local else is_global_master(args)
298
+
299
+ return is_master
300
+
301
+
302
+ class AllGather(torch.autograd.Function):
303
+ """An autograd function that performs allgather on a tensor.
304
+ Performs all_gather operation on the provided tensors.
305
+ *** Warning ***: torch.distributed.all_gather has no gradient.
306
+ """
307
+
308
+ @staticmethod
309
+ def forward(ctx, tensor, rank, world_size):
310
+ tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)]
311
+ torch.distributed.all_gather(tensors_gather, tensor)
312
+ ctx.rank = rank
313
+ ctx.batch_size = tensor.shape[0]
314
+ return torch.cat(tensors_gather, 0)
315
+
316
+ @staticmethod
317
+ def backward(ctx, grad_output):
318
+ return (grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)], None, None)
319
+
320
+
321
+ allgather = AllGather.apply
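AllGather re-implements all_gather with a backward pass that returns only the local shard's gradient (plain torch.distributed.all_gather is not differentiable). A minimal single-process sketch of the calling pattern, assuming the class above is in scope; the gloo/TCP initialization below is purely illustrative and not how the training scripts set up distributed state:

    import torch
    import torch.distributed as dist

    if not dist.is_initialized():
        dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", world_size=1, rank=0)

    features = torch.randn(4, 8, requires_grad=True)  # local batch of features
    gathered = allgather(features, dist.get_rank(), dist.get_world_size())
    print(gathered.shape)       # (4 * world_size, 8)
    gathered.sum().backward()   # gradients flow back to the local shard only
    print(features.grad.shape)  # (4, 8)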
blip3o/model/multimodal_encoder/dev_eva_clip/eva_vit.py ADDED
@@ -0,0 +1,140 @@
1
+ # Based on EVA, BEIT, timm and DeiT code bases
2
+ # https://github.com/baaivision/EVA
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
4
+ # https://github.com/microsoft/unilm/tree/master/beit
5
+ # https://github.com/facebookresearch/deit/
6
+ # https://github.com/facebookresearch/dino
7
+ # --------------------------------------------------------'
8
+ # not tested yet
9
+ import math
10
+ from transformers import CLIPImageProcessor
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import torch.utils.checkpoint as checkpoint
16
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
17
+ from .eva_clip import create_model_and_transforms, get_model_config
18
+ import torch
19
+ import torchvision
20
+ import time
21
+
22
+
23
+
24
+ class EvaViTWrapper(nn.Module):
25
+ def __init__(self, vision_tower, args, delay_load=False):
26
+ super().__init__()
27
+
28
+ self.is_loaded = False
29
+ self.vision_tower_name = vision_tower
30
+ self.pretrained = args.vision_tower_pretrained
31
+ self.args = args
32
+
33
+ self.select_layer = args.mm_vision_select_layer
34
+ if self.select_layer < -1:
35
+ self.select_layer += 1
36
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
37
+
38
+ self.model_config = get_model_config(self.vision_tower_name)
39
+
40
+ if not delay_load:
41
+ print(f"Loading vision tower: {vision_tower}")
42
+ self.load_model()
43
+ elif getattr(args, "unfreeze_mm_vision_tower", False):
44
+ # TODO: better detector is needed.
45
+ print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
46
+ self.load_model()
47
+ elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
48
+ print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
49
+ self.load_model()
50
+
51
+ def load_model(self):
52
+ print(f"Loading: {self.vision_tower_name}")
53
+ print(f"Pretrained: {self.pretrained}")
54
+ time_start = time.time()
55
+ model, _, image_processor = create_model_and_transforms(self.vision_tower_name, self.pretrained, force_custom_clip=True, precision="fp16")
56
+ time_end = time.time()
57
+ print(f"Loaded: {self.vision_tower_name} in {time_end - time_start:.2f}s")
58
+ self.device = next(model.parameters()).device
59
+ self.dtype = next(model.parameters()).dtype
60
+ if self.device.type != "meta":
61
+ model = model.to("cuda")
62
+ self.vision_tower = model.visual
63
+ resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0]
64
+ normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0]
65
+ self.resize_transform_size = resize_transform.size
66
+ self.image_processor = CLIPImageProcessor.from_pretrained(
67
+ "openai/clip-vit-large-patch14",
68
+ crop_size=resize_transform.size,
69
+ size={"shortest_edge": resize_transform.size},
70
+ image_mean=list(normalize_transform.mean),
71
+ image_std=list(normalize_transform.std),
72
+ )
73
+ print(f"Loaded image processor: {self.image_processor}")
74
+ self.vision_tower.requires_grad_(False)
75
+ self.is_loaded = True
76
+
77
+ def feature_select(self, image_features):
78
+ select_feature_type = self.select_feature
79
+
80
+ # if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
81
+ # select_every_k_layer = len(image_features) // 4
82
+ # image_features = torch.cat([image_features[i] for i in range(select_every_k_layer + self.select_layer, len(image_features), select_every_k_layer)], dim=-1)
83
+ # select_feature_type = select_feature_type.replace("slicefour_", "")
84
+ # elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
85
+ # select_layers = [-1, -4, -7, -10, 6]
86
+ # image_features = torch.cat([image_features[i] for i in select_layers], dim=-1)
87
+ # select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
88
+ # else:
89
+ # image_features = image_features[self.select_layer]
90
+
91
+ if select_feature_type == "patch":
92
+ image_features = image_features[:, 1:]
93
+ elif select_feature_type == "cls_patch":
94
+ image_features = image_features
95
+ else:
96
+ raise ValueError(f"Unexpected select feature: {select_feature_type}")
97
+ return image_features
98
+
99
+ def train(self, mode=True):
100
+ self.training = mode
101
+
102
+ if self.is_loaded:
103
+ self.vision_tower.eval()
104
+
105
+ def forward(self, images):
106
+ if type(images) is list:
107
+ image_features = []
108
+ for image in images:
109
+ image_feature = self.vision_tower.forward_features(image.to(self.dtype), return_all_features=True)
110
+ image_feature = self.feature_select(image_feature).to(self.dtype)
111
+ image_features.append(image_feature)
112
+ else:
113
+ image_features = self.vision_tower.forward_features(images.to(self.dtype), return_all_features=True)
114
+ image_features = self.feature_select(image_features).to(self.dtype)
115
+
116
+ return image_features
117
+
118
+ @property
119
+ def dummy_feature(self):
120
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
121
+
122
+ @property
123
+ def hidden_size(self):
124
+ return self.model_config["vision_cfg"]["width"]
125
+
126
+ @property
127
+ def num_patches(self):
128
+ return (self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]) ** 2
129
+
130
+ @property
131
+ def num_patches_per_side(self):
132
+ return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]
133
+
134
+ @property
135
+ def config(self):
136
+ return self.model_config
137
+
138
+ @property
139
+ def image_size(self):
140
+ return self.model_config["vision_cfg"]["image_size"]
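feature_select above only decides whether the CLS token is kept. A tiny shape-only illustration with a dummy feature tensor (no model weights involved):

    import torch

    features = torch.randn(2, 1 + 576, 1024)  # [batch, cls + patches, width]; 336/14 -> 24*24 = 576 patches

    patch_only = features[:, 1:]   # select_feature == "patch"
    cls_patch = features           # select_feature == "cls_patch"
    print(patch_only.shape, cls_patch.shape)  # [2, 576, 1024] and [2, 577, 1024]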
blip3o/model/multimodal_encoder/eva_clip/eva_clip_encoder.py ADDED
@@ -0,0 +1,75 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .eva_clip_processors import EvaClipImageTrainProcessor
5
+ from .eva_vit import EVAEncoderWrapper
6
+ from .factory import list_models, add_model_config, get_model_config
7
+
8
+
9
+
10
+ class EvaClipVisionTower(nn.Module):
11
+ def __init__(self, vision_tower, args, delay_load=False):
12
+ super().__init__()
13
+
14
+ self.is_loaded = False
15
+ self.vision_tower_name = vision_tower
16
+ self.vision_tower_pretrained = args.vision_tower_pretrained
17
+ self.config = get_model_config(vision_tower)
18
+
19
+
20
+ if not delay_load:
21
+ print(f"Loading EVA ViT: {self.vision_tower_name}")
22
+ self.load_model()
23
+ elif getattr(args, "unfreeze_mm_vision_tower", False):
24
+ # TODO: better detector is needed.
25
+ print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
26
+ self.load_model()
27
+ elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
28
+ print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
29
+ self.load_model()
30
+ else:
31
+ self.cfg_only = self.config
32
+
33
+
34
+ def load_model(self, device_map=None):
35
+ print(f"Pretrained: {self.vision_tower_pretrained}")
36
+ self.image_processor = EvaClipImageTrainProcessor(self.config["vision_cfg"]["image_size"])
37
+ self.vision_tower = EVAEncoderWrapper(self.vision_tower_pretrained, self.config)
38
+ print(f"Loaded image processor: {self.image_processor}")
39
+ self.vision_tower.requires_grad_(False)
40
+ self.is_loaded = True
41
+
42
+ def forward(self, images):
43
+ if type(images) is list:
44
+ image_features = []
45
+ for image in images:
46
+ image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)).to(image.dtype)
47
+ image_features.append(image_feature)
48
+ else:
49
+ image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(images.dtype)
50
+
51
+ return image_features
52
+
53
+ @property
54
+ def dtype(self):
55
+ return self.vision_tower.dtype
56
+
57
+ @property
58
+ def device(self):
59
+ return self.vision_tower.device
60
+
61
+ @property
62
+ def hidden_size(self):
63
+ return self.config["vision_cfg"]["width"]
64
+
65
+ @property
66
+ def num_patches(self):
67
+ return (self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"]) ** 2
68
+
69
+ @property
70
+ def num_patches_per_side(self):
71
+ return self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"]
72
+
73
+ @property
74
+ def image_size(self):
75
+ return self.config["vision_cfg"]["image_size"]
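The property accessors above just read vision_cfg fields from the model config JSON. With the EVA02-CLIP-L-14-336 settings listed further below (image_size 336, patch_size 14, width 1024), the token grid works out as in this short sketch:

    vision_cfg = {"image_size": 336, "patch_size": 14, "width": 1024}  # values copied from the config JSONs below

    num_patches_per_side = vision_cfg["image_size"] // vision_cfg["patch_size"]  # 24
    num_patches = num_patches_per_side ** 2                                      # 576
    hidden_size = vision_cfg["width"]                                            # 1024
    print(num_patches_per_side, num_patches, hidden_size)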
blip3o/model/multimodal_encoder/eva_clip/eva_clip_processors.py ADDED
@@ -0,0 +1,74 @@
1
+ """
2
+ # Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
3
+ """
4
+
5
+ from torchvision import transforms
6
+ from torchvision.transforms.functional import InterpolationMode
7
+ from transformers.image_processing_utils import BatchFeature
8
+ from PIL import Image
9
+ from transformers.image_transforms import convert_to_rgb
10
+ import numpy as np
11
+
12
+ class BaseProcessor:
13
+ def __init__(self):
14
+ self.transform = lambda x: x
15
+ return
16
+
17
+ def __call__(self, item):
18
+ return self.transform(item)
19
+
20
+
21
+ class EvaClipImageBaseProcessor(BaseProcessor):
22
+ def __init__(self, mean=None, std=None):
23
+ self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
24
+ self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
25
+ self.normalize = transforms.Normalize(self.mean, self.std)
26
+
27
+
28
+ @property
29
+ def image_mean(self):
30
+ return self.mean
31
+
32
+
33
+ class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor):
34
+ def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
35
+ super().__init__(mean=mean, std=std)
36
+
37
+ self.transform = transforms.Compose(
38
+ [
39
+ convert_to_rgb,
40
+ transforms.Resize(
41
+ image_size,
42
+ interpolation=InterpolationMode.BICUBIC,
43
+ ),
44
+ transforms.CenterCrop(image_size),
45
+ transforms.ToTensor(),
46
+ self.normalize,
47
+ ]
48
+ )
49
+
50
+
51
+ self.image_size = image_size
52
+
53
+ def preprocess(self, images, return_tensors):
54
+ if isinstance(images, Image.Image):
55
+ images = [images]
56
+ else:
57
+ assert isinstance(images, list)
58
+
59
+ transformed_images = [self.transform(image).numpy() for image in images]
60
+ data = {"pixel_values": transformed_images}
61
+
62
+
63
+ return BatchFeature(data=data, tensor_type=return_tensors)
64
+
65
+ def __call__(self, item):
66
+ return self.transform(item)
67
+
68
+ @property
69
+ def crop_size(self):
70
+ return {"height": self.image_size, "width": self.image_size}
71
+
72
+ @property
73
+ def size(self):
74
+ return {"shortest_edge": self.image_size}
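For reference, the same preprocessing recipe can be reproduced with plain torchvision transforms and run on a synthetic image; this standalone sketch mirrors what EvaClipImageTrainProcessor applies (CLIP mean/std, bicubic resize, center crop) and the shape it stacks into pixel_values:

    import numpy as np
    from PIL import Image
    from torchvision import transforms
    from torchvision.transforms.functional import InterpolationMode

    image_size = 224
    mean = (0.48145466, 0.4578275, 0.40821073)
    std = (0.26862954, 0.26130258, 0.27577711)
    pipeline = transforms.Compose([
        transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    img = Image.fromarray(np.uint8(np.random.rand(480, 640, 3) * 255))  # synthetic RGB image
    pixel_values = pipeline(img)
    print(pixel_values.shape)  # torch.Size([3, 224, 224])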
blip3o/model/multimodal_encoder/eva_clip/eva_vit.py ADDED
@@ -0,0 +1,762 @@
1
+ """
2
+ # Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
3
+ """
4
+
5
+ from math import pi
6
+ import torch
7
+ from torch import nn
8
+ from einops import rearrange, repeat
9
+ import logging
+ from functools import partial  # needed by forward_features when the RoPE patch-dropout path is taken
10
+
11
+ from huggingface_hub import snapshot_download
12
+ cache_dir = snapshot_download(repo_id="jiuhai/eva_clip_vision_tower")
13
+
14
+
15
+
16
+
17
+ def broadcat(tensors, dim=-1):
18
+ num_tensors = len(tensors)
19
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
20
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
21
+ shape_len = list(shape_lens)[0]
22
+ dim = (dim + shape_len) if dim < 0 else dim
23
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
24
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
25
+ assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), "invalid dimensions for broadcastable concatenation"
26
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
27
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
28
+ expanded_dims.insert(dim, (dim, dims[dim]))
29
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
30
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
31
+ return torch.cat(tensors, dim=dim)
32
+
33
+
34
+ def rotate_half(x):
35
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
36
+ x1, x2 = x.unbind(dim=-1)
37
+ x = torch.stack((-x2, x1), dim=-1)
38
+ return rearrange(x, "... d r -> ... (d r)")
39
+
40
+
41
+ class VisionRotaryEmbeddingFast(nn.Module):
42
+ def __init__(self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, patch_dropout=0.0):
43
+ super().__init__()
44
+ if custom_freqs:
45
+ freqs = custom_freqs
46
+ elif freqs_for == "lang":
47
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
48
+ elif freqs_for == "pixel":
49
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
50
+ elif freqs_for == "constant":
51
+ freqs = torch.ones(num_freqs).float()
52
+ else:
53
+ raise ValueError(f"unknown modality {freqs_for}")
54
+
55
+ if ft_seq_len is None:
56
+ ft_seq_len = pt_seq_len
57
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
58
+
59
+ freqs = torch.einsum("..., f -> ... f", t, freqs)
60
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
61
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
62
+
63
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
64
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
65
+
66
+ self.patch_dropout = patch_dropout
67
+
68
+ self.register_buffer("freqs_cos", freqs_cos)
69
+ self.register_buffer("freqs_sin", freqs_sin)
70
+
71
+ logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")
72
+
73
+ def forward(self, t, patch_indices_keep=None):
74
+ if patch_indices_keep is not None:
75
+ batch = t.size()[0]
76
+ batch_indices = torch.arange(batch)
77
+ batch_indices = batch_indices[..., None]
78
+
79
+ freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
80
+ freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
81
+
82
+ freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
83
+ freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j")
84
+ freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
85
+ freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j")
86
+
87
+ return t * freqs_cos + rotate_half(t) * freqs_sin
88
+
89
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
90
+
91
+
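For intuition, the rotary embedding above rotates each consecutive pair of channels by a position-dependent angle, which leaves vector norms unchanged. A self-contained sketch of the rotate_half identity using plain reshapes instead of einops (independent of the class above):

    import torch

    def rotate_half(x):
        # (..., d) with d even, grouped as pairs (x1, x2) -> (-x2, x1)
        x = x.view(*x.shape[:-1], -1, 2)
        x1, x2 = x.unbind(dim=-1)
        return torch.stack((-x2, x1), dim=-1).flatten(-2)

    t = torch.randn(2, 8)                                    # 4 channel pairs per row
    theta = torch.rand(2, 4).repeat_interleave(2, dim=-1)    # one angle per pair, duplicated per channel
    rotated = t * theta.cos() + rotate_half(t) * theta.sin()
    print(torch.allclose(t.norm(dim=-1), rotated.norm(dim=-1), atol=1e-5))  # True: rotation preserves norm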
92
+ class LayerNorm(nn.LayerNorm):
93
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
94
+
95
+ def forward(self, x: torch.Tensor):
96
+ orig_type = x.dtype
97
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
98
+ return x.to(orig_type)
99
+
100
+
101
+ class PatchDropout(nn.Module):
102
+ """
103
+ https://arxiv.org/abs/2212.00794
104
+ """
105
+
106
+ def __init__(self, prob, exclude_first_token=True):
107
+ super().__init__()
108
+ assert 0 <= prob < 1.
109
+ self.prob = prob
110
+ self.exclude_first_token = exclude_first_token # exclude CLS token
111
+ print(f"os.getenv('RoPE')={os.getenv('RoPE')}")
112
+
113
+ def forward(self, x):
114
+ if not self.training or self.prob == 0.:
115
+ return x
116
+
117
+ if self.exclude_first_token:
118
+ cls_tokens, x = x[:, :1], x[:, 1:]
119
+ else:
120
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
121
+
122
+ batch = x.size()[0]
123
+ num_tokens = x.size()[1]
124
+
125
+ batch_indices = torch.arange(batch)
126
+ batch_indices = batch_indices[..., None]
127
+
128
+ keep_prob = 1 - self.prob
129
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
130
+
131
+ rand = torch.randn(batch, num_tokens)
132
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
133
+
134
+ x = x[batch_indices, patch_indices_keep]
135
+
136
+ if self.exclude_first_token:
137
+ x = torch.cat((cls_tokens, x), dim=1)
138
+
139
+ if self.training and os.getenv('RoPE') == '1':
140
+ return x, patch_indices_keep
141
+
142
+ return x
143
+
144
+
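The keep-count arithmetic in PatchDropout is easy to check in isolation: with prob=0.5 and 256 patch tokens, 128 patches survive per sample (plus the excluded CLS token). A standalone sketch of the same topk-based selection:

    import torch

    prob, batch, num_tokens = 0.5, 2, 256
    keep = max(1, int(num_tokens * (1 - prob)))    # 128 tokens kept per sample
    scores = torch.randn(batch, num_tokens)
    keep_idx = scores.topk(keep, dim=-1).indices   # a random subset, chosen per sample
    x = torch.randn(batch, num_tokens, 16)
    x_kept = x[torch.arange(batch)[:, None], keep_idx]
    print(keep, x_kept.shape)                      # 128 torch.Size([2, 128, 16])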
145
+ # --------------------------------------------------------
146
+ # Adapted from https://github.com/microsoft/unilm/tree/master/beit
147
+ # --------------------------------------------------------
148
+ import math
149
+ import os
150
+ import torch.nn as nn
151
+ import torch.nn.functional as F
152
+
153
+ try:
154
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
155
+ except:
156
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
157
+
158
+ if os.getenv("ENV_TYPE") == "deepspeed":
159
+ try:
160
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
161
+ except:
162
+ from torch.utils.checkpoint import checkpoint
163
+ else:
164
+ from torch.utils.checkpoint import checkpoint
165
+
166
+ try:
167
+ import xformers.ops as xops
168
+ except ImportError:
169
+ xops = None
170
+ # print("Please 'pip install xformers'")
171
+
172
+
173
+ class DropPath(nn.Module):
174
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
175
+ """
176
+ def __init__(self, drop_prob=None):
177
+ super(DropPath, self).__init__()
178
+ self.drop_prob = drop_prob
179
+
180
+ def forward(self, x):
181
+ return drop_path(x, self.drop_prob, self.training)
182
+
183
+ def extra_repr(self) -> str:
184
+ return 'p={}'.format(self.drop_prob)
185
+
186
+
187
+ class Mlp(nn.Module):
188
+ def __init__(
189
+ self,
190
+ in_features,
191
+ hidden_features=None,
192
+ out_features=None,
193
+ act_layer=nn.GELU,
194
+ norm_layer=nn.LayerNorm,
195
+ drop=0.,
196
+ subln=False,
197
+
198
+ ):
199
+ super().__init__()
200
+ out_features = out_features or in_features
201
+ hidden_features = hidden_features or in_features
202
+ self.fc1 = nn.Linear(in_features, hidden_features)
203
+ self.act = act_layer()
204
+
205
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
206
+
207
+ self.fc2 = nn.Linear(hidden_features, out_features)
208
+ self.drop = nn.Dropout(drop)
209
+
210
+ def forward(self, x):
211
+ x = self.fc1(x)
212
+ x = self.act(x)
213
+ # x = self.drop(x)
214
+ # first dropout is commented out, following the original BERT implementation
215
+ x = self.ffn_ln(x)
216
+
217
+ x = self.fc2(x)
218
+ x = self.drop(x)
219
+ return x
220
+
221
+
222
+ class SwiGLU(nn.Module):
223
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
224
+ norm_layer=nn.LayerNorm, subln=False):
225
+ super().__init__()
226
+ out_features = out_features or in_features
227
+ hidden_features = hidden_features or in_features
228
+
229
+ self.w1 = nn.Linear(in_features, hidden_features)
230
+ self.w2 = nn.Linear(in_features, hidden_features)
231
+
232
+ self.act = act_layer()
233
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
234
+ self.w3 = nn.Linear(hidden_features, out_features)
235
+
236
+ self.drop = nn.Dropout(drop)
237
+
238
+ def forward(self, x):
239
+ x1 = self.w1(x)
240
+ x2 = self.w2(x)
241
+ hidden = self.act(x1) * x2
242
+ x = self.ffn_ln(hidden)
243
+ x = self.w3(x)
244
+ x = self.drop(x)
245
+ return x
246
+
247
+
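SwiGLU above gates one linear projection with the SiLU of another before projecting back down. A minimal standalone forward pass with toy sizes (no sub-LN, no dropout), just to show the shapes:

    import torch
    import torch.nn as nn

    dim, hidden = 16, 48
    w1, w2, w3 = nn.Linear(dim, hidden), nn.Linear(dim, hidden), nn.Linear(hidden, dim)
    act = nn.SiLU()

    x = torch.randn(2, 5, dim)        # [batch, tokens, dim]
    out = w3(act(w1(x)) * w2(x))      # gated hidden state projected back to dim
    print(out.shape)                  # torch.Size([2, 5, 16])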
248
+ class Attention(nn.Module):
249
+ def __init__(
250
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
251
+ proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
252
+ super().__init__()
253
+ self.num_heads = num_heads
254
+ head_dim = dim // num_heads
255
+ if attn_head_dim is not None:
256
+ head_dim = attn_head_dim
257
+ all_head_dim = head_dim * self.num_heads
258
+ self.scale = qk_scale or head_dim ** -0.5
259
+
260
+ self.subln = subln
261
+ if self.subln:
262
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
263
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
264
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
265
+ else:
266
+ if qkv_bias:
267
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=True)
268
+ else:
269
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
270
+
271
+ # q_bias / v_bias are read in forward() on the subln path, so they must be defined here
272
+ if qkv_bias and self.subln:
273
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
274
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
275
+ else:
276
+ self.q_bias = None
+ self.v_bias = None
277
+
278
+ self.window_size = None
279
+ self.relative_position_bias_table = None
280
+ self.relative_position_index = None
281
+
282
+ self.attn_drop = nn.Dropout(attn_drop)
283
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
284
+ # self.proj = nn.Linear(all_head_dim, all_head_dim)
285
+ self.proj = nn.Linear(all_head_dim, dim)
286
+ self.proj_drop = nn.Dropout(proj_drop)
287
+ self.xattn = xattn
288
+ self.xattn_drop = attn_drop
289
+
290
+ self.rope = rope
291
+
292
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
293
+ B, N, C = x.shape
294
+ if self.subln:
295
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
296
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
297
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
298
+
299
+ q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
300
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
301
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
302
+ else:
303
+
304
+ # qkv_bias = None
305
+ # if self.q_bias is not None:
306
+ # qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
307
+
308
+ # qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
309
+
310
+ qkv = self.qkv(x)
311
+
312
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
313
+ q, k, v = qkv[0], qkv[1], qkv[2]
314
+
315
+ if self.rope:
316
+ q_t = q[:, :, 1:, :]
317
+ ro_q_t = self.rope(q_t)
318
+ q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
319
+
320
+ k_t = k[:, :, 1:, :]
321
+ ro_k_t = self.rope(k_t)
322
+ k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
323
+
324
+ if self.xattn:
325
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
326
+ k = k.permute(0, 2, 1, 3)
327
+ v = v.permute(0, 2, 1, 3)
328
+
329
+ x = xops.memory_efficient_attention(
330
+ q, k, v,
331
+ p=self.xattn_drop,
332
+ scale=self.scale,
333
+ )
334
+ x = x.reshape(B, N, -1)
335
+ x = self.inner_attn_ln(x)
336
+ x = self.proj(x)
337
+ x = self.proj_drop(x)
338
+ else:
339
+ q = q * self.scale
340
+ attn = (q @ k.transpose(-2, -1))
341
+
342
+ if self.relative_position_bias_table is not None:
343
+ relative_position_bias = \
344
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
345
+ self.window_size[0] * self.window_size[1] + 1,
346
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
347
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
348
+ attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
349
+
350
+ if rel_pos_bias is not None:
351
+ attn = attn + rel_pos_bias.type_as(attn)
352
+
353
+ if attn_mask is not None:
354
+ attn_mask = attn_mask.bool()
355
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
356
+
357
+ attn = attn.softmax(dim=-1)
358
+ attn = self.attn_drop(attn)
359
+
360
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
361
+ x = self.inner_attn_ln(x)
362
+ x = self.proj(x)
363
+ x = self.proj_drop(x)
364
+ return x
365
+
366
+
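With xattn disabled, the branch above is plain scaled dot-product attention. As a sanity check, the manual computation matches torch.nn.functional.scaled_dot_product_attention (PyTorch 2.x) when dropout is off and the default 1/sqrt(head_dim) scale is used; a standalone sketch:

    import torch
    import torch.nn.functional as F

    B, H, N, D = 2, 4, 10, 16
    q, k, v = (torch.randn(B, H, N, D) for _ in range(3))
    scale = D ** -0.5

    attn = (q * scale) @ k.transpose(-2, -1)             # [B, H, N, N]
    out_manual = attn.softmax(dim=-1) @ v                # [B, H, N, D]
    out_sdpa = F.scaled_dot_product_attention(q, k, v)   # same math, fused kernel
    print(torch.allclose(out_manual, out_sdpa, atol=1e-5))  # True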
367
+ class Block(nn.Module):
368
+
369
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
370
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
371
+ window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
372
+ subln=False, naiveswiglu=False):
373
+ super().__init__()
374
+ self.norm1 = norm_layer(dim)
375
+ self.attn = Attention(
376
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
377
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
378
+ xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
379
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
380
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
381
+ self.norm2 = norm_layer(dim)
382
+ mlp_hidden_dim = int(dim * mlp_ratio)
383
+
384
+ if naiveswiglu:
385
+ self.mlp = SwiGLU(
386
+ in_features=dim,
387
+ hidden_features=mlp_hidden_dim,
388
+ subln=subln,
389
+ norm_layer=norm_layer,
390
+ )
391
+ else:
392
+ self.mlp = Mlp(
393
+ in_features=dim,
394
+ hidden_features=mlp_hidden_dim,
395
+ act_layer=act_layer,
396
+ subln=subln,
397
+ drop=drop
398
+ )
399
+
400
+ if init_values is not None and init_values > 0:
401
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
402
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
403
+ else:
404
+ self.gamma_1, self.gamma_2 = None, None
405
+
406
+ self.postnorm = postnorm
407
+
408
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
409
+ if self.gamma_1 is None:
410
+ if self.postnorm:
411
+ x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
412
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
413
+ else:
414
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
415
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
416
+ else:
417
+ if self.postnorm:
418
+ x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
419
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
420
+ else:
421
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
422
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
423
+ return x
424
+
425
+
426
+ class PatchEmbed(nn.Module):
427
+ """ Image to Patch Embedding
428
+ """
429
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
430
+ super().__init__()
431
+ img_size = to_2tuple(img_size)
432
+ patch_size = to_2tuple(patch_size)
433
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
434
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
435
+ self.img_size = img_size
436
+ self.patch_size = patch_size
437
+ self.num_patches = num_patches
438
+
439
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
440
+
441
+ def forward(self, x, **kwargs):
442
+ B, C, H, W = x.shape
443
+ # FIXME look at relaxing size constraints
444
+ assert H == self.img_size[0] and W == self.img_size[1], \
445
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
446
+ x = self.proj(x).flatten(2).transpose(1, 2)
447
+ return x
448
+
449
+
450
+ class RelativePositionBias(nn.Module):
451
+
452
+ def __init__(self, window_size, num_heads):
453
+ super().__init__()
454
+ self.window_size = window_size
455
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
456
+ self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
457
+ # cls to token & token 2 cls & cls to cls
458
+
459
+ # get pair-wise relative position index for each token inside the window
460
+ coords_h = torch.arange(window_size[0])
461
+ coords_w = torch.arange(window_size[1])
462
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
463
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
464
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
465
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
466
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
467
+ relative_coords[:, :, 1] += window_size[1] - 1
468
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
469
+ relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
470
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
471
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
472
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
473
+ relative_position_index[0, 0] = self.num_relative_distance - 1
474
+
475
+ self.register_buffer("relative_position_index", relative_position_index)
476
+
477
+ def forward(self):
478
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
479
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
480
+
481
+
482
+ class EVAVisionTransformer(nn.Module):
483
+ """Vision Transformer with support for patch or hybrid CNN input stage"""
484
+
485
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
486
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
487
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
488
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
489
+ use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
490
+ pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False,
491
+ ):
492
+ super().__init__()
493
+ self.image_size = img_size
494
+ # self.num_classes = num_classes
495
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
496
+
497
+ self.patch_embed = PatchEmbed(
498
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
499
+ num_patches = self.patch_embed.num_patches
500
+
501
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
502
+ if use_abs_pos_emb:
503
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
504
+ else:
505
+ self.pos_embed = None
506
+ self.pos_drop = nn.Dropout(p=drop_rate)
507
+
508
+ self.rel_pos_bias = None
509
+ self.rope = None
510
+
511
+ self.naiveswiglu = naiveswiglu
512
+
513
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
514
+ self.use_rel_pos_bias = use_rel_pos_bias
515
+ self.blocks = nn.ModuleList([
516
+ Block(
517
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
518
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
519
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
520
+ xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
521
+ for i in range(depth)])
522
+
523
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
524
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
525
+
526
+ self.grad_checkpointing = grad_checkpointing
527
+
528
+ def fix_init_weight(self):
529
+ def rescale(param, layer_id):
530
+ param.div_(math.sqrt(2.0 * layer_id))
531
+
532
+ for layer_id, layer in enumerate(self.blocks):
533
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
534
+ if self.naiveswiglu:
535
+ rescale(layer.mlp.w3.weight.data, layer_id + 1)
536
+ else:
537
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
538
+
539
+ def get_cast_dtype(self) -> torch.dtype:
541
+ # SwiGLU blocks have no fc2 attribute, so pick the right projection
+ proj = self.blocks[0].mlp.w3 if self.naiveswiglu else self.blocks[0].mlp.fc2
+ return proj.weight.dtype
541
+
542
+ def _init_weights(self, m):
543
+ if isinstance(m, nn.Linear):
544
+ trunc_normal_(m.weight, std=0.02)
545
+ if m.bias is not None:
546
+ nn.init.constant_(m.bias, 0)
547
+ elif isinstance(m, nn.LayerNorm):
548
+ nn.init.constant_(m.bias, 0)
549
+ nn.init.constant_(m.weight, 1.0)
550
+
551
+ def get_num_layers(self):
552
+ return len(self.blocks)
553
+
554
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
555
+ assert unlocked_groups == 0, "partial locking not currently supported for this model"
556
+ for param in self.parameters():
557
+ param.requires_grad = False
558
+
559
+ @torch.jit.ignore
560
+ def set_grad_checkpointing(self, enable=True):
561
+ self.grad_checkpointing = enable
562
+
563
+ @torch.jit.ignore
564
+ def no_weight_decay(self):
565
+ return {"pos_embed", "cls_token"}
566
+
567
+ def get_classifier(self):
568
+ return self.head
569
+
570
+ def reset_classifier(self, num_classes, global_pool=""):
571
+ self.num_classes = num_classes
572
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
573
+
574
+ def forward_features(self, x):
575
+ x = self.patch_embed(x)
576
+ batch_size, seq_len, _ = x.size()
577
+
578
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
579
+ x = torch.cat((cls_tokens, x), dim=1)
580
+ if self.pos_embed is not None:
581
+ x = x + self.pos_embed
582
+ x = self.pos_drop(x)
583
+
584
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
585
+ if os.getenv('RoPE') == '1':
586
+ if self.training and not isinstance(self.patch_dropout, nn.Identity):
587
+ x, patch_indices_keep = self.patch_dropout(x)
588
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
589
+ else:
590
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
591
+ x = self.patch_dropout(x)
592
+ else:
593
+ x = self.patch_dropout(x)
594
+
595
+ rel_pos_bias = None
596
+
597
+ for blk in self.blocks:
598
+ if self.grad_checkpointing:
599
+ x = checkpoint(blk, x, (rel_pos_bias,))
600
+ else:
601
+ x = blk(x, rel_pos_bias=rel_pos_bias)
602
+
603
+ return x
604
+
605
+ def forward(self, x, return_all_features=False):
606
+
607
+ """
608
+ :param x: input image batch
609
+ :param return_all_features: kept for API compatibility; the raw token features from
610
+ forward_features (CLS token included) are returned either way
611
+ :return: tensor of shape [B, 1 + n_patch, C]
612
+ """
614
+
615
+ features = self.forward_features(x) # [B, n_patch, C]
616
+ return features
617
+
618
+
619
+ def load_state_dict(checkpoint_path: str, map_location: str = "cpu", model_key: str = "model|module|state_dict", is_openai: bool = False, skip_list: list = []):
620
+ if is_openai:
621
+ model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
622
+ state_dict = model.state_dict()
623
+ for key in ["input_resolution", "context_length", "vocab_size"]:
624
+ state_dict.pop(key, None)
625
+ else:
626
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
627
+ for mk in model_key.split("|"):
628
+ if isinstance(checkpoint, dict) and mk in checkpoint:
629
+ state_dict = checkpoint[mk]
630
+ break
631
+ else:
632
+ state_dict = checkpoint
633
+ if next(iter(state_dict.items()))[0].startswith("module"):
634
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
635
+
636
+ for k in skip_list:
637
+ if k in list(state_dict.keys()):
638
+ logging.info(f"Removing key {k} from pretrained checkpoint")
639
+ del state_dict[k]
640
+
641
+ if os.getenv("RoPE") == "1":
642
+ for k in list(state_dict.keys()):
643
+ if "freqs_cos" in k or "freqs_sin" in k:
644
+ del state_dict[k]
645
+ return state_dict
646
+
647
+
648
+ def load_clip_visual_state_dict(checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []):
649
+ state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
650
+ # for k in list(state_dict.keys()):
651
+ # if not k.startswith("visual."):
652
+ # del state_dict[k]
653
+ # for k in list(state_dict.keys()):
654
+ # if k.startswith("visual."):
655
+ # new_k = k[7:]
656
+ # state_dict[new_k] = state_dict[k]
657
+ # del state_dict[k]
658
+ return state_dict
659
+
660
+
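The two loading helpers above strip an optional "module." prefix that DistributedDataParallel adds when a checkpoint is saved from a wrapped model. The slicing is easy to verify on a toy dict (standalone sketch, not tied to a real checkpoint):

    state_dict = {"module.patch_embed.proj.weight": 1, "module.pos_embed": 2}

    if next(iter(state_dict.items()))[0].startswith("module"):
        state_dict = {k[7:]: v for k, v in state_dict.items()}  # drop the 7-character "module." prefix
    print(list(state_dict.keys()))  # ['patch_embed.proj.weight', 'pos_embed']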
661
+ from dataclasses import dataclass
662
+ from typing import Optional, Tuple, Union
663
+
664
+ try:
665
+ from apex.normalization import FusedLayerNorm
666
+ except:
667
+ FusedLayerNorm = LayerNorm
668
+ # print("Please build and install Nvidia apex package with option '--cuda_ext' according to https://github.com/NVIDIA/apex#from-source .")
669
+
670
+
671
+ @dataclass
672
+ class CLIPVisionCfg:
673
+ layers: Union[Tuple[int, int, int, int], int] = 12
674
+ width: int = 768
675
+ head_width: int = 64
676
+ mlp_ratio: float = 4.0
677
+ patch_size: int = 16
678
+ image_size: Union[Tuple[int, int], int] = 224
679
+ ls_init_value: Optional[float] = None # layer scale initial value
680
+ patch_dropout: float = 0.0 # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
681
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
682
+ drop_path_rate: Optional[float] = None # drop path rate
683
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
684
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
685
+ timm_pool: str = "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
686
+ timm_proj: str = "linear" # linear projection for timm model output ('linear', 'mlp', '')
687
+ timm_proj_bias: bool = False # enable bias final projection
688
+ eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
689
+ qkv_bias: bool = True
690
+ fusedLN: bool = False
691
+ xattn: bool = False
692
+ postnorm: bool = False
693
+ rope: bool = False
694
+ pt_hw_seq_len: int = 16 # 224/14
695
+ intp_freq: bool = False
696
+ naiveswiglu: bool = False
697
+ subln: bool = False
698
+
699
+
700
+ def create_norm_layer_factory(use_fused_ln, eps=1e-6):
701
+ # use_fused_ln is accepted for config compatibility, but a standard nn.LayerNorm factory is always returned
702
+ return lambda num_features: nn.LayerNorm(num_features, eps=eps)
703
+
704
+
705
+ def _build_vision_tower(vision_tower_path: str, embed_dim: int, vision_cfg: CLIPVisionCfg, **kwargs):
706
+ if isinstance(vision_cfg, dict):
707
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
708
+
709
+ if vision_cfg.eva_model_name:
710
+ vision_heads = vision_cfg.width // vision_cfg.head_width
711
+ # Determine the appropriate norm layer factory based on the configuration
712
+ norm_layer_factory = create_norm_layer_factory(vision_cfg.fusedLN, eps=1e-6)
713
+
714
+ # breakpoint()
715
+ visual = EVAVisionTransformer(
716
+ img_size=vision_cfg.image_size,
717
+ patch_size=vision_cfg.patch_size,
718
+ num_classes=embed_dim,
719
+ use_mean_pooling=vision_cfg.global_average_pool, # False
720
+ init_values=vision_cfg.ls_init_value,
721
+ patch_dropout=vision_cfg.patch_dropout,
722
+ embed_dim=vision_cfg.width,
723
+ depth=vision_cfg.layers,
724
+ num_heads=vision_heads,
725
+ mlp_ratio=vision_cfg.mlp_ratio,
726
+ qkv_bias=vision_cfg.qkv_bias,
727
+ drop_path_rate=vision_cfg.drop_path_rate,
728
+ norm_layer=norm_layer_factory,
729
+ xattn=vision_cfg.xattn,
730
+ rope=vision_cfg.rope,
731
+ postnorm=vision_cfg.postnorm,
732
+ pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14
733
+ intp_freq=vision_cfg.intp_freq,
734
+ naiveswiglu=vision_cfg.naiveswiglu,
735
+ subln=vision_cfg.subln,
736
+ )
737
+ # breakpoint()
738
+ state_dict = load_clip_visual_state_dict(vision_tower_path)
739
+ incompatible_keys = visual.load_state_dict(state_dict, strict=False)
740
+ print("EVA-CLIP incompatible_keys:", incompatible_keys)
741
+
742
+ return visual
743
+
744
+
745
+ class EVAEncoderWrapper(nn.Module):
746
+ def __init__(self, vision_tower_pretrained, config):
747
+ super(EVAEncoderWrapper, self).__init__()
748
+ self.config = config
749
+ self.config["vision_tower_path"] = os.path.join(cache_dir, "pytorch_model.bin")
750
+ self.model = _build_vision_tower(**self.config)
751
+
752
+ def forward(self, image, **kwargs):
753
+ encode = self.model(image, return_all_features=True)[:, 1:, :] # remove the CLS token
754
+ return encode
755
+
756
+ @property
757
+ def dtype(self):
758
+ return list(self.parameters())[-1].dtype
759
+
760
+ @property
761
+ def device(self):
762
+ return list(self.parameters())[-1].device
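_build_vision_tower above derives the attention head count from the config as width // head_width. For the EVA-CLIP-8B config listed below that is 4096 // 128 = 32 heads; a trivial sketch with a few fields mirrored from the JSONs:

    # Field values copied from the model_configs JSONs below; the dict itself is only an illustration.
    vision_cfg = {"width": 4096, "head_width": 128, "layers": 32, "patch_size": 14, "image_size": 224}

    vision_heads = vision_cfg["width"] // vision_cfg["head_width"]
    seq_len = (vision_cfg["image_size"] // vision_cfg["patch_size"]) ** 2 + 1  # patches + CLS token
    print(vision_heads, seq_len)  # 32 257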
blip3o/model/multimodal_encoder/eva_clip/factory.py ADDED
@@ -0,0 +1,59 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+ from typing import Optional, Tuple, Union, Dict, Any
9
+ import torch
10
+
11
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
12
+ _MODEL_CONFIGS = {} # dictionary of model architecture configs, keyed by model name
13
+
14
+
15
+ def _natural_key(string_):
16
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
17
+
18
+
19
+ def _rescan_model_configs():
20
+ global _MODEL_CONFIGS
21
+
22
+ config_ext = (".json",)
23
+ config_files = []
24
+ for config_path in _MODEL_CONFIG_PATHS:
25
+ if config_path.is_file() and config_path.suffix in config_ext:
26
+ config_files.append(config_path)
27
+ elif config_path.is_dir():
28
+ for ext in config_ext:
29
+ config_files.extend(config_path.glob(f"*{ext}"))
30
+ for cf in config_files:
31
+ with open(cf, "r", encoding="utf8") as f:
32
+ model_cfg = json.load(f)
33
+ if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")):
34
+ _MODEL_CONFIGS[cf.stem] = model_cfg
35
+
36
+ _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
37
+
38
+
39
+ _rescan_model_configs() # initial populate of model config registry
40
+
41
+
42
+ def list_models():
43
+ """enumerate available model architectures based on config files"""
44
+ return list(_MODEL_CONFIGS.keys())
45
+
46
+
47
+ def add_model_config(path):
48
+ """add model config path or file and update registry"""
49
+ if not isinstance(path, Path):
50
+ path = Path(path)
51
+ _MODEL_CONFIG_PATHS.append(path)
52
+ _rescan_model_configs()
53
+
54
+
55
+ def get_model_config(model_name):
56
+ if model_name in _MODEL_CONFIGS:
57
+ return deepcopy(_MODEL_CONFIGS[model_name])
58
+ else:
59
+ return None
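The registry above sorts config names with a natural key so that, for example, the 8B model sorts before the 18B one numerically rather than lexicographically. A standalone sketch of the same key function applied to config names shipped in this commit:

    import re

    def natural_key(s):
        return [int(t) if t.isdigit() else t for t in re.split(r"(\d+)", s.lower())]

    names = ["EVA-CLIP-18B", "EVA-CLIP-8B", "EVA02-CLIP-L-14", "EVA02-CLIP-L-14-336"]
    print(sorted(names, key=natural_key))  # "EVA-CLIP-8B" comes before "EVA-CLIP-18B"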
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-18B.json ADDED
@@ -0,0 +1,27 @@
1
+ {
2
+ "embed_dim": 1536,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 48,
6
+ "width": 5120,
7
+ "head_width": 128,
8
+ "mlp_ratio": 5,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-18b-14-x",
11
+ "drop_path_rate": 0,
12
+ "qkv_bias": false,
13
+ "xattn": true,
14
+ "postnorm": true,
15
+ "fusedLN": false,
16
+ "use_rms_norm": true
17
+ },
18
+ "text_cfg": {
19
+ "context_length": 77,
20
+ "vocab_size": 49408,
21
+ "width": 1280,
22
+ "heads": 20,
23
+ "layers": 32,
24
+ "xattn": false,
25
+ "fusedLN": false
26
+ }
27
+ }
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B-plus.json ADDED
@@ -0,0 +1,27 @@
1
+ {
2
+ "embed_dim": 1280,
3
+ "vision_cfg": {
4
+ "image_size": 448,
5
+ "layers": 32,
6
+ "width": 4096,
7
+ "head_width": 128,
8
+ "mlp_ratio": 5,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-8b-14-plus-x",
11
+ "drop_path_rate": 0,
12
+ "qkv_bias": false,
13
+ "xattn": true,
14
+ "postnorm": false,
15
+ "fusedLN": false,
16
+ "use_rms_norm": true
17
+ },
18
+ "text_cfg": {
19
+ "context_length": 77,
20
+ "vocab_size": 49408,
21
+ "width": 1280,
22
+ "heads": 20,
23
+ "layers": 32,
24
+ "xattn": false,
25
+ "fusedLN": false
26
+ }
27
+ }
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B.json ADDED
@@ -0,0 +1,27 @@
1
+ {
2
+ "embed_dim": 1280,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 32,
6
+ "width": 4096,
7
+ "head_width": 128,
8
+ "mlp_ratio": 5,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-8b-14-x",
11
+ "drop_path_rate": 0,
12
+ "qkv_bias": false,
13
+ "xattn": true,
14
+ "postnorm": false,
15
+ "fusedLN": false,
16
+ "use_rms_norm": true
17
+ },
18
+ "text_cfg": {
19
+ "context_length": 77,
20
+ "vocab_size": 49408,
21
+ "width": 1280,
22
+ "heads": 20,
23
+ "layers": 32,
24
+ "xattn": false,
25
+ "fusedLN": false
26
+ }
27
+ }
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-B-16.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "patch_size": 16,
8
+ "eva_model_name": "eva-clip-b-16",
9
+ "ls_init_value": 0.1,
10
+ "drop_path_rate": 0.0
11
+ },
12
+ "text_cfg": {
13
+ "context_length": 77,
14
+ "vocab_size": 49408,
15
+ "width": 512,
16
+ "heads": 8,
17
+ "layers": 12
18
+ }
19
+ }
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 40,
6
+ "width": 1408,
7
+ "head_width": 88,
8
+ "mlp_ratio": 4.3637,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-g-14-x",
11
+ "drop_path_rate": 0,
12
+ "xattn": true,
13
+ "fusedLN": true
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 1024,
19
+ "heads": 16,
20
+ "layers": 24,
21
+ "xattn": false,
22
+ "fusedLN": true
23
+ }
24
+ }
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 40,
6
+ "width": 1408,
7
+ "head_width": 88,
8
+ "mlp_ratio": 4.3637,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-g-14-x",
11
+ "drop_path_rate": 0.4,
12
+ "xattn": true,
13
+ "fusedLN": true
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 768,
19
+ "heads": 12,
20
+ "layers": 12,
21
+ "xattn": false,
22
+ "fusedLN": true
23
+ }
24
+ }
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-B-16.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 12,
6
+ "width": 768,
7
+ "head_width": 64,
8
+ "patch_size": 16,
9
+ "mlp_ratio": 2.6667,
10
+ "eva_model_name": "eva-clip-b-16-X",
11
+ "drop_path_rate": 0.0,
12
+ "xattn": true,
13
+ "fusedLN": true,
14
+ "rope": true,
15
+ "pt_hw_seq_len": 16,
16
+ "intp_freq": true,
17
+ "naiveswiglu": true,
18
+ "subln": true
19
+ },
20
+ "text_cfg": {
21
+ "context_length": 77,
22
+ "vocab_size": 49408,
23
+ "width": 512,
24
+ "heads": 8,
25
+ "layers": 12,
26
+ "xattn": true,
27
+ "fusedLN": true
28
+ }
29
+ }
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14-336.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 336,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "drop_path_rate": 0,
8
+ "head_width": 64,
9
+ "mlp_ratio": 2.6667,
10
+ "patch_size": 14,
11
+ "eva_model_name": "eva-clip-l-14-336",
12
+ "xattn": true,
13
+ "fusedLN": true,
14
+ "rope": true,
15
+ "pt_hw_seq_len": 16,
16
+ "intp_freq": true,
17
+ "naiveswiglu": true,
18
+ "subln": true
19
+ },
20
+ "text_cfg": {
21
+ "context_length": 77,
22
+ "vocab_size": 49408,
23
+ "width": 768,
24
+ "heads": 12,
25
+ "layers": 12,
26
+ "xattn": false,
27
+ "fusedLN": true
28
+ }
29
+ }
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 24,
6
+ "width": 1024,
7
+ "drop_path_rate": 0,
8
+ "head_width": 64,
9
+ "mlp_ratio": 2.6667,
10
+ "patch_size": 14,
11
+ "eva_model_name": "eva-clip-l-14",
12
+ "xattn": true,
13
+ "fusedLN": true,
14
+ "rope": true,
15
+ "pt_hw_seq_len": 16,
16
+ "intp_freq": true,
17
+ "naiveswiglu": true,
18
+ "subln": true
19
+ },
20
+ "text_cfg": {
21
+ "context_length": 77,
22
+ "vocab_size": 49408,
23
+ "width": 768,
24
+ "heads": 12,
25
+ "layers": 12,
26
+ "xattn": false,
27
+ "fusedLN": true
28
+ }
29
+ }
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json ADDED
@@ -0,0 +1,28 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 64,
6
+ "width": 1792,
7
+ "head_width": 112,
8
+ "mlp_ratio": 8.571428571428571,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-4b-14-x",
11
+ "drop_path_rate": 0,
12
+ "xattn": true,
13
+ "postnorm": true,
14
+ "fusedLN": true
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 1280,
20
+ "heads": 20,
21
+ "layers": 32,
22
+ "xattn": false,
23
+ "fusedLN": true
24
+ }
25
+ }
26
+
27
+
28
+
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": 64,
6
+ "width": 1792,
7
+ "head_width": 112,
8
+ "mlp_ratio": 8.571428571428571,
9
+ "patch_size": 14,
10
+ "eva_model_name": "eva-clip-4b-14-x",
11
+ "drop_path_rate": 0,
12
+ "xattn": true,
13
+ "postnorm": true,
14
+ "fusedLN": true
15
+ },
16
+ "text_cfg": {
17
+ "context_length": 77,
18
+ "vocab_size": 49408,
19
+ "width": 1024,
20
+ "heads": 16,
21
+ "layers": 24,
22
+ "xattn": false,
23
+ "fusedLN": true
24
+ }
25
+ }