Upload 69 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- blip3o/.DS_Store +0 -0
- blip3o/__init__.py +1 -0
- blip3o/constants.py +26 -0
- blip3o/conversation.py +479 -0
- blip3o/mm_utils.py +247 -0
- blip3o/model/__init__.py +3 -0
- blip3o/model/apply_delta.py +48 -0
- blip3o/model/blip3o_arch.py +415 -0
- blip3o/model/builder.py +54 -0
- blip3o/model/consolidate.py +25 -0
- blip3o/model/language_model/blip3o_llama.py +413 -0
- blip3o/model/language_model/blip3o_qwen.py +420 -0
- blip3o/model/lumina_nextdit2d.py +365 -0
- blip3o/model/make_delta.py +48 -0
- blip3o/model/multimodal_encoder/builder.py +63 -0
- blip3o/model/multimodal_encoder/clip_encoder.py +172 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py +9 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py +2 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/eva_vit_model.py +571 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/factory.py +528 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py +57 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_model.py +240 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py +123 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/model.py +429 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py +179 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py +144 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/pretrained.py +314 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py +131 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py +114 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/tokenizer.py +205 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py +104 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/transformer.py +683 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/utils.py +321 -0
- blip3o/model/multimodal_encoder/dev_eva_clip/eva_vit.py +140 -0
- blip3o/model/multimodal_encoder/eva_clip/eva_clip_encoder.py +75 -0
- blip3o/model/multimodal_encoder/eva_clip/eva_clip_processors.py +74 -0
- blip3o/model/multimodal_encoder/eva_clip/eva_vit.py +762 -0
- blip3o/model/multimodal_encoder/eva_clip/factory.py +59 -0
- blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-18B.json +27 -0
- blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B-plus.json +27 -0
- blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B.json +27 -0
- blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-B-16.json +19 -0
- blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json +24 -0
- blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14.json +24 -0
- blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-B-16.json +29 -0
- blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14-336.json +29 -0
- blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14.json +29 -0
- blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json +28 -0
- blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14.json +25 -0
blip3o/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
blip3o/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .model import blip3oLlamaForCausalLM
|
blip3o/constants.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
2 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
3 |
+
|
4 |
+
LOGDIR = "."
|
5 |
+
|
6 |
+
# Model Constants
|
7 |
+
IGNORE_INDEX = -100
|
8 |
+
# IMAGE_TOKEN_INDEX = -200
|
9 |
+
|
10 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
11 |
+
DEFAULT_IM_START_TOKEN = "[IMG]"
|
12 |
+
DEFAULT_IM_END_TOKEN = "[/IMG]"
|
13 |
+
|
14 |
+
|
15 |
+
# IMAGE_TOKEN_IDX = 32002
|
16 |
+
# DEFAULT_IM_START_TOKEN_IDX = 32000
|
17 |
+
# DEFAULT_IM_END_TOKEN_IDX = 32001
|
18 |
+
|
19 |
+
IMAGE_TOKEN_IDX = 151667
|
20 |
+
DEFAULT_IM_START_TOKEN_IDX = 128257
|
21 |
+
DEFAULT_IM_END_TOKEN_IDX = 128258
|
22 |
+
UND_IMAGE_TOKEN_IDX = 151655
|
23 |
+
# N_QUERY = 729
|
24 |
+
|
25 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
26 |
+
IMAGE_PLACEHOLDER = "<image-placeholder>"
|
blip3o/conversation.py
ADDED
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Tuple
|
4 |
+
import base64
|
5 |
+
from io import BytesIO
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
|
9 |
+
class SeparatorStyle(Enum):
|
10 |
+
"""Different separator style."""
|
11 |
+
SINGLE = auto()
|
12 |
+
TWO = auto()
|
13 |
+
MPT = auto()
|
14 |
+
PLAIN = auto()
|
15 |
+
LLAMA_2 = auto()
|
16 |
+
CHATML = auto()
|
17 |
+
QWEN = auto()
|
18 |
+
|
19 |
+
|
20 |
+
@dataclasses.dataclass
|
21 |
+
class Conversation:
|
22 |
+
"""A class that keeps all conversation history."""
|
23 |
+
system: str
|
24 |
+
roles: List[str]
|
25 |
+
messages: List[List[str]]
|
26 |
+
offset: int
|
27 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
28 |
+
sep: str = "###"
|
29 |
+
sep2: str = None
|
30 |
+
version: str = "Unknown"
|
31 |
+
|
32 |
+
skip_next: bool = False
|
33 |
+
|
34 |
+
def get_prompt(self):
|
35 |
+
messages = self.messages
|
36 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
37 |
+
messages = self.messages.copy()
|
38 |
+
init_role, init_msg = messages[0].copy()
|
39 |
+
init_msg = init_msg[0]
|
40 |
+
if "mmtag" in self.version:
|
41 |
+
init_msg = init_msg.replace("<image>", "").strip()
|
42 |
+
messages[0] = (init_role, init_msg)
|
43 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
44 |
+
messages.insert(1, (self.roles[1], "Received."))
|
45 |
+
elif not init_msg.startswith("<image>"):
|
46 |
+
init_msg = init_msg.replace("<image>", "").strip()
|
47 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
48 |
+
else:
|
49 |
+
messages[0] = (init_role, init_msg)
|
50 |
+
|
51 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
52 |
+
ret = self.system + self.sep
|
53 |
+
for role, message in messages:
|
54 |
+
if message:
|
55 |
+
if type(message) is tuple:
|
56 |
+
message, _, _ = message
|
57 |
+
ret += role + ": " + message + self.sep
|
58 |
+
else:
|
59 |
+
ret += role + ":"
|
60 |
+
|
61 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
62 |
+
seps = [self.sep, self.sep2]
|
63 |
+
ret = self.system + seps[0]
|
64 |
+
for i, (role, message) in enumerate(messages):
|
65 |
+
if message:
|
66 |
+
if type(message) is tuple:
|
67 |
+
message, _, _ = message
|
68 |
+
ret += role + ": " + message + seps[i % 2]
|
69 |
+
else:
|
70 |
+
ret += role + ":"
|
71 |
+
|
72 |
+
elif self.sep_style == SeparatorStyle.CHATML:
|
73 |
+
ret = "" if self.system == "" else self.system + self.sep + "\n"
|
74 |
+
for role, message in messages:
|
75 |
+
if message:
|
76 |
+
if type(message) is tuple:
|
77 |
+
message, images, _ = message
|
78 |
+
message = "<image>" * len(images) + message
|
79 |
+
ret += role + "\n" + message + self.sep + "\n"
|
80 |
+
else:
|
81 |
+
ret += role + "\n"
|
82 |
+
return ret
|
83 |
+
|
84 |
+
elif self.sep_style == SeparatorStyle.LLAMA_3:
|
85 |
+
if self.tokenizer is None:
|
86 |
+
raise ValueError("Llama 3 tokenizer is not available. Make sure you have the necessary permissions.")
|
87 |
+
chat_template_messages = [{"role": "system", "content": self.system}]
|
88 |
+
for role, message in messages:
|
89 |
+
if message:
|
90 |
+
if type(message) is tuple:
|
91 |
+
message, images = message
|
92 |
+
message = "<image>" * len(images) + message
|
93 |
+
chat_template_messages.append({"role": role, "content": message})
|
94 |
+
|
95 |
+
# print(chat_template_messages)
|
96 |
+
return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True)
|
97 |
+
# ret = "" if self.system == "" else self.system + self.sep + "\n"
|
98 |
+
# for role, message in messages:
|
99 |
+
# if message:
|
100 |
+
# if type(message) is tuple:
|
101 |
+
# message, images = message
|
102 |
+
# message = "<image>" * len(images) + message
|
103 |
+
# ret += role + "\n" + message + self.sep + "\n"
|
104 |
+
# else:
|
105 |
+
# ret += role + "\n"
|
106 |
+
# return ret
|
107 |
+
|
108 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
109 |
+
ret = self.system + self.sep
|
110 |
+
for role, message in messages:
|
111 |
+
if message:
|
112 |
+
if type(message) is tuple:
|
113 |
+
message, _, _ = message
|
114 |
+
ret += role + message + self.sep
|
115 |
+
else:
|
116 |
+
ret += role
|
117 |
+
|
118 |
+
elif self.sep_style == SeparatorStyle.GEMMA:
|
119 |
+
ret = ""
|
120 |
+
for i, (role, message) in enumerate(messages):
|
121 |
+
assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
|
122 |
+
if message:
|
123 |
+
if type(message) is tuple:
|
124 |
+
message, _, _ = message
|
125 |
+
ret += role + message + self.sep
|
126 |
+
else:
|
127 |
+
ret += role
|
128 |
+
|
129 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
130 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
|
131 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
132 |
+
ret = ""
|
133 |
+
|
134 |
+
for i, (role, message) in enumerate(messages):
|
135 |
+
if i == 0:
|
136 |
+
assert message, "first message should not be none"
|
137 |
+
assert role == self.roles[0], "first message should come from user"
|
138 |
+
if message:
|
139 |
+
if type(message) is tuple:
|
140 |
+
message, _, _ = message
|
141 |
+
if i == 0:
|
142 |
+
message = wrap_sys(self.system) + message
|
143 |
+
if i % 2 == 0:
|
144 |
+
message = wrap_inst(message)
|
145 |
+
ret += self.sep + message
|
146 |
+
else:
|
147 |
+
ret += " " + message + " " + self.sep2
|
148 |
+
else:
|
149 |
+
ret += ""
|
150 |
+
ret = ret.lstrip(self.sep)
|
151 |
+
|
152 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
153 |
+
seps = [self.sep, self.sep2]
|
154 |
+
ret = self.system
|
155 |
+
for i, (role, message) in enumerate(messages):
|
156 |
+
if message:
|
157 |
+
if type(message) is tuple:
|
158 |
+
message, _, _ = message
|
159 |
+
ret += message + seps[i % 2]
|
160 |
+
else:
|
161 |
+
ret += ""
|
162 |
+
else:
|
163 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
164 |
+
|
165 |
+
return ret
|
166 |
+
|
167 |
+
def append_message(self, role, message):
|
168 |
+
self.messages.append([role, message])
|
169 |
+
|
170 |
+
def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
|
171 |
+
if image_process_mode == "Pad":
|
172 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
173 |
+
width, height = pil_img.size
|
174 |
+
if width == height:
|
175 |
+
return pil_img
|
176 |
+
elif width > height:
|
177 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
178 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
179 |
+
return result
|
180 |
+
else:
|
181 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
182 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
183 |
+
return result
|
184 |
+
|
185 |
+
image = expand2square(image)
|
186 |
+
elif image_process_mode in ["Default", "Crop"]:
|
187 |
+
pass
|
188 |
+
elif image_process_mode == "Resize":
|
189 |
+
image = image.resize((336, 336))
|
190 |
+
else:
|
191 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
192 |
+
if max(image.size) > max_len:
|
193 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
194 |
+
aspect_ratio = max_hw / min_hw
|
195 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
196 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
197 |
+
W, H = image.size
|
198 |
+
if H > W:
|
199 |
+
H, W = longest_edge, shortest_edge
|
200 |
+
else:
|
201 |
+
H, W = shortest_edge, longest_edge
|
202 |
+
image = image.resize((W, H))
|
203 |
+
if return_pil:
|
204 |
+
return image
|
205 |
+
else:
|
206 |
+
buffered = BytesIO()
|
207 |
+
image.save(buffered, format=image_format)
|
208 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
209 |
+
return img_b64_str
|
210 |
+
|
211 |
+
def get_images(self, return_pil=False):
|
212 |
+
images = []
|
213 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
214 |
+
if i % 2 == 0:
|
215 |
+
if type(msg) is tuple:
|
216 |
+
msg, image, image_process_mode = msg
|
217 |
+
image = self.process_image(image, image_process_mode, return_pil=return_pil)
|
218 |
+
images.append(image)
|
219 |
+
return images
|
220 |
+
|
221 |
+
def to_gradio_chatbot(self):
|
222 |
+
ret = []
|
223 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
224 |
+
if i % 2 == 0:
|
225 |
+
if type(msg) is tuple:
|
226 |
+
msg, image, image_process_mode = msg
|
227 |
+
img_b64_str = self.process_image(
|
228 |
+
image, "Default", return_pil=False,
|
229 |
+
image_format='JPEG')
|
230 |
+
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
|
231 |
+
msg = img_str + msg.replace('<image>', '').strip()
|
232 |
+
ret.append([msg, None])
|
233 |
+
else:
|
234 |
+
ret.append([msg, None])
|
235 |
+
else:
|
236 |
+
ret[-1][-1] = msg
|
237 |
+
return ret
|
238 |
+
|
239 |
+
def copy(self):
|
240 |
+
return Conversation(
|
241 |
+
system=self.system,
|
242 |
+
roles=self.roles,
|
243 |
+
messages=[[x, y] for x, y in self.messages],
|
244 |
+
offset=self.offset,
|
245 |
+
sep_style=self.sep_style,
|
246 |
+
sep=self.sep,
|
247 |
+
sep2=self.sep2,
|
248 |
+
version=self.version)
|
249 |
+
|
250 |
+
def dict(self):
|
251 |
+
if len(self.get_images()) > 0:
|
252 |
+
return {
|
253 |
+
"system": self.system,
|
254 |
+
"roles": self.roles,
|
255 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
256 |
+
"offset": self.offset,
|
257 |
+
"sep": self.sep,
|
258 |
+
"sep2": self.sep2,
|
259 |
+
}
|
260 |
+
return {
|
261 |
+
"system": self.system,
|
262 |
+
"roles": self.roles,
|
263 |
+
"messages": self.messages,
|
264 |
+
"offset": self.offset,
|
265 |
+
"sep": self.sep,
|
266 |
+
"sep2": self.sep2,
|
267 |
+
}
|
268 |
+
|
269 |
+
|
270 |
+
conv_vicuna_v0 = Conversation(
|
271 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
272 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
273 |
+
roles=("Human", "Assistant"),
|
274 |
+
messages=(
|
275 |
+
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
|
276 |
+
("Assistant",
|
277 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
278 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
279 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
280 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
281 |
+
"renewable and non-renewable energy sources:\n"
|
282 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
283 |
+
"energy sources are finite and will eventually run out.\n"
|
284 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
285 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
286 |
+
"and other negative effects.\n"
|
287 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
288 |
+
"have lower operational costs than non-renewable sources.\n"
|
289 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
290 |
+
"locations than non-renewable sources.\n"
|
291 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
292 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
293 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
294 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
|
295 |
+
),
|
296 |
+
offset=2,
|
297 |
+
sep_style=SeparatorStyle.SINGLE,
|
298 |
+
sep="###",
|
299 |
+
)
|
300 |
+
|
301 |
+
conv_vicuna_v1 = Conversation(
|
302 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
303 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
304 |
+
roles=("USER", "ASSISTANT"),
|
305 |
+
version="v1",
|
306 |
+
messages=(),
|
307 |
+
offset=0,
|
308 |
+
sep_style=SeparatorStyle.TWO,
|
309 |
+
sep=" ",
|
310 |
+
sep2="</s>",
|
311 |
+
)
|
312 |
+
|
313 |
+
conv_llama_2 = Conversation(
|
314 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
315 |
+
|
316 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
317 |
+
roles=("USER", "ASSISTANT"),
|
318 |
+
version="llama_v2",
|
319 |
+
messages=(),
|
320 |
+
offset=0,
|
321 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
322 |
+
sep="<s>",
|
323 |
+
sep2="</s>",
|
324 |
+
)
|
325 |
+
|
326 |
+
|
327 |
+
conv_blip3o_llama_2 = Conversation(
|
328 |
+
system="You are a helpful language and vision assistant. "
|
329 |
+
"You are able to understand the visual content that the user provides, "
|
330 |
+
"and assist the user with a variety of tasks using natural language.",
|
331 |
+
roles=("USER", "ASSISTANT"),
|
332 |
+
version="llama_v2",
|
333 |
+
messages=(),
|
334 |
+
offset=0,
|
335 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
336 |
+
sep="<s>",
|
337 |
+
sep2="</s>",
|
338 |
+
)
|
339 |
+
|
340 |
+
conv_mpt = Conversation(
|
341 |
+
system="""<|im_start|>system
|
342 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
343 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
344 |
+
version="mpt",
|
345 |
+
messages=(),
|
346 |
+
offset=0,
|
347 |
+
sep_style=SeparatorStyle.MPT,
|
348 |
+
sep="<|im_end|>",
|
349 |
+
)
|
350 |
+
|
351 |
+
conv_blip3o_plain = Conversation(
|
352 |
+
system="",
|
353 |
+
roles=("", ""),
|
354 |
+
messages=(
|
355 |
+
),
|
356 |
+
offset=0,
|
357 |
+
sep_style=SeparatorStyle.PLAIN,
|
358 |
+
sep="\n",
|
359 |
+
)
|
360 |
+
|
361 |
+
conv_blip3o_v0 = Conversation(
|
362 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
363 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
364 |
+
roles=("Human", "Assistant"),
|
365 |
+
messages=(
|
366 |
+
),
|
367 |
+
offset=0,
|
368 |
+
sep_style=SeparatorStyle.SINGLE,
|
369 |
+
sep="###",
|
370 |
+
)
|
371 |
+
|
372 |
+
conv_blip3o_v0_mmtag = Conversation(
|
373 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
374 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
375 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
376 |
+
roles=("Human", "Assistant"),
|
377 |
+
messages=(
|
378 |
+
),
|
379 |
+
offset=0,
|
380 |
+
sep_style=SeparatorStyle.SINGLE,
|
381 |
+
sep="###",
|
382 |
+
version="v0_mmtag",
|
383 |
+
)
|
384 |
+
|
385 |
+
conv_blip3o_v1 = Conversation(
|
386 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
387 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
388 |
+
roles=("USER", "ASSISTANT"),
|
389 |
+
version="v1",
|
390 |
+
messages=(),
|
391 |
+
offset=0,
|
392 |
+
sep_style=SeparatorStyle.TWO,
|
393 |
+
sep=" ",
|
394 |
+
sep2="</s>",
|
395 |
+
)
|
396 |
+
|
397 |
+
conv_blip3o_v1_mmtag = Conversation(
|
398 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
399 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
400 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
401 |
+
roles=("USER", "ASSISTANT"),
|
402 |
+
messages=(),
|
403 |
+
offset=0,
|
404 |
+
sep_style=SeparatorStyle.TWO,
|
405 |
+
sep=" ",
|
406 |
+
sep2="</s>",
|
407 |
+
version="v1_mmtag",
|
408 |
+
)
|
409 |
+
|
410 |
+
conv_mistral_instruct = Conversation(
|
411 |
+
system="",
|
412 |
+
roles=("USER", "ASSISTANT"),
|
413 |
+
version="llama_v2",
|
414 |
+
messages=(),
|
415 |
+
offset=0,
|
416 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
417 |
+
sep="",
|
418 |
+
sep2="</s>",
|
419 |
+
)
|
420 |
+
|
421 |
+
conv_chatml_direct = Conversation(
|
422 |
+
system="""<|im_start|>system
|
423 |
+
Answer the questions.""",
|
424 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
425 |
+
version="mpt",
|
426 |
+
messages=(),
|
427 |
+
offset=0,
|
428 |
+
sep_style=SeparatorStyle.MPT,
|
429 |
+
sep="<|im_end|>",
|
430 |
+
)
|
431 |
+
|
432 |
+
conv_llama3 = Conversation(
|
433 |
+
system="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""",
|
434 |
+
roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
|
435 |
+
version="llama3",
|
436 |
+
messages=(),
|
437 |
+
offset=0,
|
438 |
+
sep_style=SeparatorStyle.MPT,
|
439 |
+
sep="<|eot_id|>",
|
440 |
+
)
|
441 |
+
|
442 |
+
conv_qwen = Conversation(
|
443 |
+
system="""<|im_start|>system
|
444 |
+
You are a helpful assistant.""",
|
445 |
+
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
446 |
+
version="qwen",
|
447 |
+
messages=[],
|
448 |
+
offset=0,
|
449 |
+
sep_style=SeparatorStyle.CHATML,
|
450 |
+
sep="<|im_end|>",
|
451 |
+
)
|
452 |
+
|
453 |
+
|
454 |
+
default_conversation = conv_llama3
|
455 |
+
conv_templates = {
|
456 |
+
"default": conv_vicuna_v0,
|
457 |
+
"v0": conv_vicuna_v0,
|
458 |
+
"v1": conv_vicuna_v1,
|
459 |
+
"vicuna_v1": conv_vicuna_v1,
|
460 |
+
"llama_2": conv_llama_2,
|
461 |
+
"mistral_instruct": conv_mistral_instruct,
|
462 |
+
"chatml_direct": conv_chatml_direct,
|
463 |
+
"mistral_direct": conv_chatml_direct,
|
464 |
+
|
465 |
+
"plain": conv_blip3o_plain,
|
466 |
+
"v0_plain": conv_blip3o_plain,
|
467 |
+
"blip3o_v0": conv_blip3o_v0,
|
468 |
+
"v0_mmtag": conv_blip3o_v0_mmtag,
|
469 |
+
"blip3o_v1": conv_blip3o_v1,
|
470 |
+
"v1_mmtag": conv_blip3o_v1_mmtag,
|
471 |
+
"blip3o_llama_2": conv_blip3o_llama_2,
|
472 |
+
"llama3": conv_llama3,
|
473 |
+
"qwen": conv_qwen,
|
474 |
+
|
475 |
+
"mpt": conv_mpt,
|
476 |
+
}
|
477 |
+
|
478 |
+
if __name__ == "__main__":
|
479 |
+
print(default_conversation.get_prompt())
|
blip3o/mm_utils.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from io import BytesIO
|
3 |
+
import base64
|
4 |
+
import torch
|
5 |
+
import math
|
6 |
+
import ast
|
7 |
+
|
8 |
+
from transformers import StoppingCriteria
|
9 |
+
from blip3o.constants import IMAGE_TOKEN_IDX
|
10 |
+
|
11 |
+
|
12 |
+
def select_best_resolution(original_size, possible_resolutions):
|
13 |
+
"""
|
14 |
+
Selects the best resolution from a list of possible resolutions based on the original size.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
original_size (tuple): The original size of the image in the format (width, height).
|
18 |
+
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
19 |
+
|
20 |
+
Returns:
|
21 |
+
tuple: The best fit resolution in the format (width, height).
|
22 |
+
"""
|
23 |
+
original_width, original_height = original_size
|
24 |
+
best_fit = None
|
25 |
+
max_effective_resolution = 0
|
26 |
+
min_wasted_resolution = float('inf')
|
27 |
+
|
28 |
+
for width, height in possible_resolutions:
|
29 |
+
scale = min(width / original_width, height / original_height)
|
30 |
+
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
|
31 |
+
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
|
32 |
+
wasted_resolution = (width * height) - effective_resolution
|
33 |
+
|
34 |
+
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
|
35 |
+
max_effective_resolution = effective_resolution
|
36 |
+
min_wasted_resolution = wasted_resolution
|
37 |
+
best_fit = (width, height)
|
38 |
+
|
39 |
+
return best_fit
|
40 |
+
|
41 |
+
|
42 |
+
def resize_and_pad_image(image, target_resolution):
|
43 |
+
"""
|
44 |
+
Resize and pad an image to a target resolution while maintaining aspect ratio.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
image (PIL.Image.Image): The input image.
|
48 |
+
target_resolution (tuple): The target resolution (width, height) of the image.
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
PIL.Image.Image: The resized and padded image.
|
52 |
+
"""
|
53 |
+
original_width, original_height = image.size
|
54 |
+
target_width, target_height = target_resolution
|
55 |
+
|
56 |
+
scale_w = target_width / original_width
|
57 |
+
scale_h = target_height / original_height
|
58 |
+
|
59 |
+
if scale_w < scale_h:
|
60 |
+
new_width = target_width
|
61 |
+
new_height = min(math.ceil(original_height * scale_w), target_height)
|
62 |
+
else:
|
63 |
+
new_height = target_height
|
64 |
+
new_width = min(math.ceil(original_width * scale_h), target_width)
|
65 |
+
|
66 |
+
# Resize the image
|
67 |
+
resized_image = image.resize((new_width, new_height))
|
68 |
+
|
69 |
+
new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
|
70 |
+
paste_x = (target_width - new_width) // 2
|
71 |
+
paste_y = (target_height - new_height) // 2
|
72 |
+
new_image.paste(resized_image, (paste_x, paste_y))
|
73 |
+
|
74 |
+
return new_image
|
75 |
+
|
76 |
+
|
77 |
+
def divide_to_patches(image, patch_size):
|
78 |
+
"""
|
79 |
+
Divides an image into patches of a specified size.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
image (PIL.Image.Image): The input image.
|
83 |
+
patch_size (int): The size of each patch.
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
list: A list of PIL.Image.Image objects representing the patches.
|
87 |
+
"""
|
88 |
+
patches = []
|
89 |
+
width, height = image.size
|
90 |
+
for i in range(0, height, patch_size):
|
91 |
+
for j in range(0, width, patch_size):
|
92 |
+
box = (j, i, j + patch_size, i + patch_size)
|
93 |
+
patch = image.crop(box)
|
94 |
+
patches.append(patch)
|
95 |
+
|
96 |
+
return patches
|
97 |
+
|
98 |
+
|
99 |
+
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
100 |
+
"""
|
101 |
+
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
image_size (tuple): The size of the input image in the format (width, height).
|
105 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
106 |
+
patch_size (int): The size of each image patch.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
tuple: The shape of the image patch grid in the format (width, height).
|
110 |
+
"""
|
111 |
+
if type(grid_pinpoints) is list:
|
112 |
+
possible_resolutions = grid_pinpoints
|
113 |
+
else:
|
114 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
115 |
+
width, height = select_best_resolution(image_size, possible_resolutions)
|
116 |
+
return width // patch_size, height // patch_size
|
117 |
+
|
118 |
+
|
119 |
+
def process_anyres_image(image, processor, grid_pinpoints):
|
120 |
+
"""
|
121 |
+
Process an image with variable resolutions.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
image (PIL.Image.Image): The input image to be processed.
|
125 |
+
processor: The image processor object.
|
126 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
torch.Tensor: A tensor containing the processed image patches.
|
130 |
+
"""
|
131 |
+
if type(grid_pinpoints) is list:
|
132 |
+
possible_resolutions = grid_pinpoints
|
133 |
+
else:
|
134 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
135 |
+
best_resolution = select_best_resolution(image.size, possible_resolutions)
|
136 |
+
image_padded = resize_and_pad_image(image, best_resolution)
|
137 |
+
|
138 |
+
patches = divide_to_patches(image_padded, processor.crop_size['height'])
|
139 |
+
|
140 |
+
image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
|
141 |
+
|
142 |
+
image_patches = [image_original_resize] + patches
|
143 |
+
image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
|
144 |
+
for image_patch in image_patches]
|
145 |
+
return torch.stack(image_patches, dim=0)
|
146 |
+
|
147 |
+
|
148 |
+
def load_image_from_base64(image):
|
149 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
150 |
+
|
151 |
+
|
152 |
+
def expand2square(pil_img, background_color):
|
153 |
+
width, height = pil_img.size
|
154 |
+
if width == height:
|
155 |
+
return pil_img
|
156 |
+
elif width > height:
|
157 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
158 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
159 |
+
return result
|
160 |
+
else:
|
161 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
162 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
163 |
+
return result
|
164 |
+
|
165 |
+
|
166 |
+
def process_images(images, image_processor, model_cfg):
|
167 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
168 |
+
new_images = []
|
169 |
+
if image_aspect_ratio == 'pad':
|
170 |
+
for image in images:
|
171 |
+
image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
|
172 |
+
image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
173 |
+
new_images.append(image)
|
174 |
+
elif image_aspect_ratio == "anyres":
|
175 |
+
for image in images:
|
176 |
+
image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
|
177 |
+
new_images.append(image)
|
178 |
+
else:
|
179 |
+
return image_processor(images, return_tensors='pt')['pixel_values']
|
180 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
181 |
+
new_images = torch.stack(new_images, dim=0)
|
182 |
+
return new_images
|
183 |
+
|
184 |
+
|
185 |
+
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_IDX, return_tensors=None):
|
186 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
|
187 |
+
|
188 |
+
def insert_separator(X, sep):
|
189 |
+
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
190 |
+
|
191 |
+
input_ids = []
|
192 |
+
offset = 0
|
193 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
194 |
+
offset = 1
|
195 |
+
input_ids.append(prompt_chunks[0][0])
|
196 |
+
|
197 |
+
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
198 |
+
input_ids.extend(x[offset:])
|
199 |
+
|
200 |
+
if return_tensors is not None:
|
201 |
+
if return_tensors == 'pt':
|
202 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
203 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
204 |
+
return input_ids
|
205 |
+
|
206 |
+
|
207 |
+
def get_model_name_from_path(model_path):
|
208 |
+
model_path = model_path.strip("/")
|
209 |
+
model_paths = model_path.split("/")
|
210 |
+
if model_paths[-1].startswith('checkpoint-'):
|
211 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
212 |
+
else:
|
213 |
+
return model_paths[-1]
|
214 |
+
|
215 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
216 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
217 |
+
self.keywords = keywords
|
218 |
+
self.keyword_ids = []
|
219 |
+
self.max_keyword_len = 0
|
220 |
+
for keyword in keywords:
|
221 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
222 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
223 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
224 |
+
if len(cur_keyword_ids) > self.max_keyword_len:
|
225 |
+
self.max_keyword_len = len(cur_keyword_ids)
|
226 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
227 |
+
self.tokenizer = tokenizer
|
228 |
+
self.start_len = input_ids.shape[1]
|
229 |
+
|
230 |
+
def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
231 |
+
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
232 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
233 |
+
for keyword_id in self.keyword_ids:
|
234 |
+
truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
|
235 |
+
if torch.equal(truncated_output_ids, keyword_id):
|
236 |
+
return True
|
237 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
238 |
+
for keyword in self.keywords:
|
239 |
+
if keyword in outputs:
|
240 |
+
return True
|
241 |
+
return False
|
242 |
+
|
243 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
244 |
+
outputs = []
|
245 |
+
for i in range(output_ids.shape[0]):
|
246 |
+
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
|
247 |
+
return all(outputs)
|
blip3o/model/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .language_model.blip3o_llama import blip3oLlamaForCausalLM, blip3oConfig
|
2 |
+
from .language_model.blip3o_qwen import blip3oQwenForCausalLM, blip3oQwenConfig
|
3 |
+
|
blip3o/model/apply_delta.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage:
|
3 |
+
python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from tqdm import tqdm
|
9 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
10 |
+
from blip3o import blip3oLlamaForCausalLM
|
11 |
+
|
12 |
+
|
13 |
+
def apply_delta(base_model_path, target_model_path, delta_path):
|
14 |
+
print("Loading base model")
|
15 |
+
base = AutoModelForCausalLM.from_pretrained(
|
16 |
+
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
17 |
+
|
18 |
+
print("Loading delta")
|
19 |
+
delta = blip3oLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
20 |
+
delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
|
21 |
+
|
22 |
+
print("Applying delta")
|
23 |
+
for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
|
24 |
+
if name not in base.state_dict():
|
25 |
+
assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
|
26 |
+
continue
|
27 |
+
if param.data.shape == base.state_dict()[name].shape:
|
28 |
+
param.data += base.state_dict()[name]
|
29 |
+
else:
|
30 |
+
assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
|
31 |
+
f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
|
32 |
+
bparam = base.state_dict()[name]
|
33 |
+
param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
|
34 |
+
|
35 |
+
print("Saving target model")
|
36 |
+
delta.save_pretrained(target_model_path)
|
37 |
+
delta_tokenizer.save_pretrained(target_model_path)
|
38 |
+
|
39 |
+
|
40 |
+
if __name__ == "__main__":
|
41 |
+
parser = argparse.ArgumentParser()
|
42 |
+
parser.add_argument("--base-model-path", type=str, required=True)
|
43 |
+
parser.add_argument("--target-model-path", type=str, required=True)
|
44 |
+
parser.add_argument("--delta-path", type=str, required=True)
|
45 |
+
|
46 |
+
args = parser.parse_args()
|
47 |
+
|
48 |
+
apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
|
blip3o/model/blip3o_arch.py
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from .multimodal_encoder.builder import build_vision_tower, build_gen_vision_tower, build_dit
|
8 |
+
from .multimodal_projector.builder import build_vision_projector, build_down_projector, build_gen_vision_projector
|
9 |
+
|
10 |
+
from blip3o.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_IDX, DEFAULT_IM_START_TOKEN_IDX, DEFAULT_IM_END_TOKEN_IDX, UND_IMAGE_TOKEN_IDX
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
class blip3oMetaModel:
|
15 |
+
|
16 |
+
def __init__(self, config):
|
17 |
+
super(blip3oMetaModel, self).__init__(config)
|
18 |
+
|
19 |
+
if hasattr(config, "mm_vision_tower"):
|
20 |
+
# self.vision_tower = build_vision_tower(config, delay_load=True)
|
21 |
+
# self.mm_projector = build_vision_projector(config)
|
22 |
+
self.down_projector = build_down_projector(config)
|
23 |
+
|
24 |
+
if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
|
25 |
+
self.image_newline = nn.Parameter(
|
26 |
+
torch.empty(config.hidden_size, dtype=self.dtype)
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
if hasattr(config, "gen_vision_tower"):
|
31 |
+
self.gen_vision_tower = build_gen_vision_tower(config, delay_load=True)
|
32 |
+
# self.gen_projector = build_gen_vision_projector(config)
|
33 |
+
self.latent_queries = nn.Parameter(torch.randn(1, config.n_query, config.hidden_size))
|
34 |
+
print(f" latent query size {self.latent_queries.shape}")
|
35 |
+
|
36 |
+
if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
|
37 |
+
self.image_newline = nn.Parameter(
|
38 |
+
torch.empty(config.hidden_size, dtype=self.dtype)
|
39 |
+
)
|
40 |
+
|
41 |
+
self.dit, self.vae, self.noise_scheduler = build_dit(config)
|
42 |
+
|
43 |
+
|
44 |
+
# def get_vision_tower(self):
|
45 |
+
# vision_tower = getattr(self, 'vision_tower', None)
|
46 |
+
# if type(vision_tower) is list:
|
47 |
+
# vision_tower = vision_tower[0]
|
48 |
+
# return vision_tower
|
49 |
+
|
50 |
+
|
51 |
+
def get_gen_vision_tower(self):
|
52 |
+
gen_vision_tower = getattr(self, 'gen_vision_tower', None)
|
53 |
+
if type(gen_vision_tower) is list:
|
54 |
+
gen_vision_tower = gen_vision_tower[0]
|
55 |
+
return gen_vision_tower
|
56 |
+
|
57 |
+
|
58 |
+
def initialize_vision_modules(self, model_args, fsdp=None):
|
59 |
+
gen_vision_tower = model_args.gen_vision_tower
|
60 |
+
|
61 |
+
mm_vision_select_layer = model_args.mm_vision_select_layer
|
62 |
+
mm_vision_select_feature = model_args.mm_vision_select_feature
|
63 |
+
|
64 |
+
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
65 |
+
pretrain_gen_mlp_adapter = model_args.pretrain_gen_mlp_adapter
|
66 |
+
|
67 |
+
mm_patch_merge_type = model_args.mm_patch_merge_type
|
68 |
+
|
69 |
+
self.config.gen_vision_tower = gen_vision_tower
|
70 |
+
self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
if getattr(self, 'dit', None) is None:
|
75 |
+
print("random initiation the DiT !!!")
|
76 |
+
self.dit, self.vae, self.noise_scheduler = build_dit(model_args)
|
77 |
+
else:
|
78 |
+
print("DiT load from checkpoint!!!")
|
79 |
+
for p in self.dit.parameters():
|
80 |
+
p.requires_grad = True
|
81 |
+
|
82 |
+
|
83 |
+
if self.get_gen_vision_tower() is None:
|
84 |
+
gen_vision_tower = build_gen_vision_tower(model_args)
|
85 |
+
|
86 |
+
if fsdp is not None and len(fsdp) > 0:
|
87 |
+
self.gen_vision_tower = [gen_vision_tower]
|
88 |
+
else:
|
89 |
+
self.gen_vision_tower = gen_vision_tower
|
90 |
+
else:
|
91 |
+
if fsdp is not None and len(fsdp) > 0:
|
92 |
+
gen_vision_tower = self.gen_vision_tower[0]
|
93 |
+
else:
|
94 |
+
gen_vision_tower = self.gen_vision_tower
|
95 |
+
gen_vision_tower.load_model()
|
96 |
+
|
97 |
+
|
98 |
+
self.config.use_mm_proj = True
|
99 |
+
self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
|
100 |
+
# self.config.gen_projector_type = getattr(model_args, 'gen_projector_type', 'linear')
|
101 |
+
|
102 |
+
|
103 |
+
self.config.gen_hidden_size = gen_vision_tower.hidden_size
|
104 |
+
|
105 |
+
self.config.mm_vision_select_layer = mm_vision_select_layer
|
106 |
+
self.config.mm_vision_select_feature = mm_vision_select_feature
|
107 |
+
self.config.mm_patch_merge_type = mm_patch_merge_type
|
108 |
+
self.config.n_query = model_args.n_query
|
109 |
+
self.config.gen_pooling = model_args.gen_pooling
|
110 |
+
|
111 |
+
# if getattr(self, 'mm_projector', None) is None:
|
112 |
+
# print("random initiation the mm_project !!!")
|
113 |
+
# self.mm_projector = build_vision_projector(self.config)
|
114 |
+
|
115 |
+
# if 'unpad' in mm_patch_merge_type:
|
116 |
+
# embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
|
117 |
+
# self.image_newline = nn.Parameter(
|
118 |
+
# torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
|
119 |
+
# )
|
120 |
+
# else:
|
121 |
+
# # In case it is frozen by LoRA
|
122 |
+
# for p in self.mm_projector.parameters():
|
123 |
+
# p.requires_grad = True
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
if getattr(self, 'down_projector', None) is None:
|
128 |
+
print("random initiation the down_projector !!!")
|
129 |
+
self.down_projector = build_down_projector(self.config)
|
130 |
+
else:
|
131 |
+
# In case it is frozen by LoRA
|
132 |
+
for p in self.down_projector.parameters():
|
133 |
+
p.requires_grad = True
|
134 |
+
|
135 |
+
if getattr(self, 'latent_queries', None) is None:
|
136 |
+
print("random initiation the latent_queries !!!")
|
137 |
+
self.latent_queries = nn.Parameter(torch.randn(1, self.config.n_query, self.config.hidden_size))
|
138 |
+
else:
|
139 |
+
print("latent_queries load from checkpoint!!!")
|
140 |
+
self.latent_queries.requires_grad = True
|
141 |
+
|
142 |
+
|
143 |
+
if pretrain_mm_mlp_adapter is not None:
|
144 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
|
145 |
+
def get_w(weights, keyword):
|
146 |
+
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
|
147 |
+
|
148 |
+
# self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
|
149 |
+
|
150 |
+
|
151 |
+
|
152 |
+
def unpad_image(tensor, original_size):
|
153 |
+
"""
|
154 |
+
Unpads a PyTorch tensor of a padded and resized image.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
|
158 |
+
original_size (tuple): The original size of PIL image (width, height).
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
torch.Tensor: The unpadded image tensor.
|
162 |
+
"""
|
163 |
+
original_width, original_height = original_size
|
164 |
+
current_height, current_width = tensor.shape[1:]
|
165 |
+
|
166 |
+
original_aspect_ratio = original_width / original_height
|
167 |
+
current_aspect_ratio = current_width / current_height
|
168 |
+
|
169 |
+
if original_aspect_ratio > current_aspect_ratio:
|
170 |
+
scale_factor = current_width / original_width
|
171 |
+
new_height = int(original_height * scale_factor)
|
172 |
+
padding = (current_height - new_height) // 2
|
173 |
+
unpadded_tensor = tensor[:, padding:current_height - padding, :]
|
174 |
+
else:
|
175 |
+
scale_factor = current_height / original_height
|
176 |
+
new_width = int(original_width * scale_factor)
|
177 |
+
padding = (current_width - new_width) // 2
|
178 |
+
unpadded_tensor = tensor[:, :, padding:current_width - padding]
|
179 |
+
|
180 |
+
return unpadded_tensor
|
181 |
+
|
182 |
+
|
183 |
+
class blip3oMetaForCausalLM(ABC):
|
184 |
+
|
185 |
+
@abstractmethod
|
186 |
+
def get_model(self):
|
187 |
+
pass
|
188 |
+
|
189 |
+
def get_vision_tower(self):
|
190 |
+
return self.get_model().get_vision_tower()
|
191 |
+
|
192 |
+
def get_gen_vision_tower(self):
|
193 |
+
return self.get_model().get_gen_vision_tower()
|
194 |
+
|
195 |
+
def encode_image(self, images):
|
196 |
+
# breakpoint()
|
197 |
+
gen_vision_tower = self.get_gen_vision_tower()
|
198 |
+
device = gen_vision_tower.device
|
199 |
+
images = images.to(device)
|
200 |
+
prompt_image_embeds = gen_vision_tower(images)
|
201 |
+
if 'early' in self.get_gen_pooling():
|
202 |
+
prompt_image_embeds = self.pool_img(prompt_image_embeds)
|
203 |
+
num_img, _, c = prompt_image_embeds.shape
|
204 |
+
# prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c)
|
205 |
+
|
206 |
+
# ------------- compute similarity -------
|
207 |
+
all_dist = 0
|
208 |
+
count = 0
|
209 |
+
for i in range(2, prompt_image_embeds.shape[1]-1):
|
210 |
+
diff = (prompt_image_embeds[:,i,:].unsqueeze(1) - prompt_image_embeds[:,:i,:])
|
211 |
+
dist = torch.sqrt(diff.square().sum(-1)).min().item()
|
212 |
+
all_dist+=dist
|
213 |
+
count+=1
|
214 |
+
all_dist /= count
|
215 |
+
# self.dist = all_dist
|
216 |
+
# print(self.dist)
|
217 |
+
|
218 |
+
return prompt_image_embeds
|
219 |
+
|
220 |
+
def get_mm_projector(self):
|
221 |
+
return self.get_model().mm_projector
|
222 |
+
|
223 |
+
def get_gen_projector(self):
|
224 |
+
return None
|
225 |
+
|
226 |
+
|
227 |
+
def get_n_query(self):
|
228 |
+
return self.get_model().config.n_query
|
229 |
+
|
230 |
+
def get_gen_pooling(self):
|
231 |
+
return self.get_model().config.gen_pooling
|
232 |
+
|
233 |
+
def pool_img(self, image_features):
|
234 |
+
num_img, n, c = image_features.shape
|
235 |
+
gen_pooling = self.get_gen_pooling()
|
236 |
+
# n_query = self.get_n_query()
|
237 |
+
stride = int(gen_pooling.split('_')[-1])
|
238 |
+
sqrt_n = int(n**0.5)
|
239 |
+
image_features = image_features.permute(0, 2, 1).view(num_img, c, sqrt_n, sqrt_n)
|
240 |
+
image_features = F.avg_pool2d(image_features, kernel_size=(stride, stride), stride=stride)
|
241 |
+
# image_features = image_features.view(num_img, c, -1).permute(0,2,1).contiguous()
|
242 |
+
return image_features
|
243 |
+
|
244 |
+
def get_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32):
|
245 |
+
sigmas = self.get_model().noise_scheduler.sigmas.to(device=device, dtype=dtype)
|
246 |
+
schedule_timesteps = self.get_model().noise_scheduler.timesteps.to(device=device)
|
247 |
+
timesteps = timesteps.to(device)
|
248 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
249 |
+
|
250 |
+
sigma = sigmas[step_indices].flatten()
|
251 |
+
while len(sigma.shape) < n_dim:
|
252 |
+
sigma = sigma.unsqueeze(-1)
|
253 |
+
return sigma
|
254 |
+
|
255 |
+
def mask_drop(self, latents, drop_prob=0.1):
|
256 |
+
if drop_prob <= 0:
|
257 |
+
return latents
|
258 |
+
mask = torch.bernoulli(torch.zeros(latents.shape[0], device=latents.device, dtype=latents.dtype) + drop_prob)
|
259 |
+
while len(mask.shape) < len(latents.shape):
|
260 |
+
mask = mask.unsqueeze(-1)
|
261 |
+
mask = 1 - mask # need to flip 0 <-> 1
|
262 |
+
return latents * mask
|
263 |
+
|
264 |
+
def prepare_inputs_labels_for_multimodal(
|
265 |
+
self, input_ids, position_ids, attention_mask, past_key_values, labels,
|
266 |
+
gen_images, und_images, grid_thw, i_s_pos, image_sizes=None
|
267 |
+
):
|
268 |
+
pad_ids = 128256
|
269 |
+
vision_tower = self.visual
|
270 |
+
gen_vision_tower = self.get_gen_vision_tower()
|
271 |
+
if (gen_images is None and und_images is None) or input_ids.shape[1] == 1:
|
272 |
+
return input_ids, position_ids, attention_mask, past_key_values, None, labels, None, None, None
|
273 |
+
|
274 |
+
|
275 |
+
|
276 |
+
|
277 |
+
prompt_image_embeds = gen_vision_tower(gen_images) # TODO: check dimension
|
278 |
+
|
279 |
+
if 'early' in self.get_gen_pooling():
|
280 |
+
prompt_image_embeds = self.pool_img(prompt_image_embeds)
|
281 |
+
target_image_embeds = torch.clone(prompt_image_embeds).detach()
|
282 |
+
latent_queries = self.get_model().latent_queries.repeat(input_ids.shape[0], 1, 1)
|
283 |
+
H = latent_queries.shape[-1]
|
284 |
+
latent_queries = latent_queries.contiguous().view(-1, H)
|
285 |
+
|
286 |
+
|
287 |
+
|
288 |
+
|
289 |
+
# if not gen_images is None:
|
290 |
+
# prompt_image_embeds = gen_vision_tower(gen_images) # TODO: check dimension
|
291 |
+
# if 'early' in self.get_gen_pooling():
|
292 |
+
# prompt_image_embeds = self.pool_img(prompt_image_embeds)
|
293 |
+
# # num_img, _, c = prompt_image_embeds.shape # [batch, 729, 1152]
|
294 |
+
# # prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c)
|
295 |
+
# target_image_embeds = torch.clone(prompt_image_embeds).detach()
|
296 |
+
# # prompt_image_embeds = gen_projector(prompt_image_embeds)
|
297 |
+
# latent_queries = self.get_model().latent_queries.repeat(input_ids.shape[0], 1, 1)
|
298 |
+
# H = latent_queries.shape[-1]
|
299 |
+
# latent_queries = latent_queries.contiguous().view(-1, H)
|
300 |
+
# else:
|
301 |
+
# target_image_embeds = None
|
302 |
+
# num_img = und_images.shape[0]
|
303 |
+
# dummy = torch.zeros(num_img, 3, 448, 448 , dtype=und_images.dtype, device=und_images.device) # TODO
|
304 |
+
# temp = gen_vision_tower(dummy)[:,:729,:]
|
305 |
+
# num_img, _, c = temp.shape
|
306 |
+
# temp = temp.contiguous().view(-1, c) * 1e-20
|
307 |
+
# # temp = gen_projector(temp) * 1e-9
|
308 |
+
# latent_queries = self.get_model().latent_queries.repeat(input_ids.shape[0], 1, 1)
|
309 |
+
# H = latent_queries.shape[-1]
|
310 |
+
# latent_queries = latent_queries.contiguous().view(-1, H)
|
311 |
+
|
312 |
+
|
313 |
+
if not und_images is None:
|
314 |
+
und_image_embeds = vision_tower(und_images, grid_thw=grid_thw)
|
315 |
+
# _, c = und_image_embeds.shape
|
316 |
+
# batch_size = und_images.shape[0]
|
317 |
+
# und_image_embeds = und_image_embeds.view(batch_size, -1, c)
|
318 |
+
# und_image_embeds = und_image_embeds.contiguous().view(-1, c)
|
319 |
+
# und_image_embeds = mm_projector(und_image_embeds)
|
320 |
+
|
321 |
+
# else:
|
322 |
+
# num_img = input_ids.shape[0]
|
323 |
+
# dummy = torch.zeros(num_img, 3, 384, 384 , dtype=gen_images.dtype, device=gen_images.device) # clip (3, 336, 336)
|
324 |
+
# temp = vision_tower(dummy)
|
325 |
+
# if 'early' in self.get_gen_pooling():
|
326 |
+
# temp = temp[:,:64,:]
|
327 |
+
# num_img, _, c = temp.shape
|
328 |
+
# temp = temp.contiguous().view(-1, c)
|
329 |
+
# temp = mm_projector(temp) * 1e-20
|
330 |
+
# latent_queries += temp
|
331 |
+
|
332 |
+
|
333 |
+
|
334 |
+
|
335 |
+
|
336 |
+
image_idx = (input_ids == IMAGE_TOKEN_IDX)
|
337 |
+
und_image_idx = (input_ids == UND_IMAGE_TOKEN_IDX)
|
338 |
+
# img_indicator = torch.clone(image_idx)
|
339 |
+
output_indicator = labels != -100
|
340 |
+
input_indicator = labels == -100
|
341 |
+
# img_loss_indicator = torch.logical_and(output_indicator, image_idx)
|
342 |
+
# img_loss_indicator = torch.cat(
|
343 |
+
# [img_loss_indicator[:, 1:], img_loss_indicator[:, :1]], dim=1)
|
344 |
+
|
345 |
+
# img_indicator = torch.cat(
|
346 |
+
# [img_indicator[:, 1:], img_indicator[:, :1]], dim=1)
|
347 |
+
|
348 |
+
# if not target_image_embeds is None:
|
349 |
+
# target_image_embeds = target_image_embeds[-img_loss_indicator.sum():,:]
|
350 |
+
text_embeds = self.get_model().embed_tokens(input_ids)
|
351 |
+
# N_QUERY = self.get_n_query()
|
352 |
+
gen_img_idx = torch.logical_and(output_indicator, image_idx)
|
353 |
+
|
354 |
+
# if not target_image_embeds is None:
|
355 |
+
text_embeds = text_embeds.clone()
|
356 |
+
text_embeds[gen_img_idx] = latent_queries
|
357 |
+
# text_embeds[gen_img_idx] = prompt_image_embeds.to(text_embeds.device)[:gen_img_idx.sum(),:]
|
358 |
+
# target_image_embeds = target_image_embeds.to(text_embeds.device)[:gen_img_idx.sum(),:]
|
359 |
+
|
360 |
+
und_img_idx = torch.logical_and(input_indicator, und_image_idx)
|
361 |
+
|
362 |
+
|
363 |
+
if not und_images is None:
|
364 |
+
text_embeds[und_img_idx] = und_image_embeds.to(text_embeds.device)[:und_img_idx.sum(), :]
|
365 |
+
|
366 |
+
labels[image_idx] = -100
|
367 |
+
|
368 |
+
|
369 |
+
return None, position_ids, attention_mask, past_key_values, text_embeds, labels, target_image_embeds
|
370 |
+
|
371 |
+
|
372 |
+
|
373 |
+
def initialize_vision_tokenizer(self, model_args, tokenizer):
|
374 |
+
if model_args.mm_use_im_patch_token:
|
375 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
376 |
+
self.resize_token_embeddings(len(tokenizer))
|
377 |
+
|
378 |
+
if model_args.mm_use_im_start_end:
|
379 |
+
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
380 |
+
self.resize_token_embeddings(len(tokenizer))
|
381 |
+
|
382 |
+
if num_new_tokens > 0:
|
383 |
+
input_embeddings = self.get_input_embeddings().weight.data
|
384 |
+
output_embeddings = self.get_output_embeddings().weight.data
|
385 |
+
|
386 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
387 |
+
dim=0, keepdim=True)
|
388 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
389 |
+
dim=0, keepdim=True)
|
390 |
+
|
391 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
392 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
393 |
+
|
394 |
+
if model_args.tune_mm_mlp_adapter:
|
395 |
+
for p in self.get_input_embeddings().parameters():
|
396 |
+
p.requires_grad = True
|
397 |
+
for p in self.get_output_embeddings().parameters():
|
398 |
+
p.requires_grad = False
|
399 |
+
|
400 |
+
if model_args.pretrain_mm_mlp_adapter:
|
401 |
+
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
|
402 |
+
embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
|
403 |
+
assert num_new_tokens == 2
|
404 |
+
if input_embeddings.shape == embed_tokens_weight.shape:
|
405 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
|
406 |
+
elif embed_tokens_weight.shape[0] == num_new_tokens:
|
407 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight
|
408 |
+
else:
|
409 |
+
raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
|
410 |
+
elif model_args.mm_use_im_patch_token:
|
411 |
+
if model_args.tune_mm_mlp_adapter:
|
412 |
+
for p in self.get_input_embeddings().parameters():
|
413 |
+
p.requires_grad = False
|
414 |
+
for p in self.get_output_embeddings().parameters():
|
415 |
+
p.requires_grad = False
|
blip3o/model/builder.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import warnings
|
3 |
+
import shutil
|
4 |
+
|
5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
6 |
+
import torch
|
7 |
+
from blip3o.model import *
|
8 |
+
from blip3o.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
9 |
+
from blip3o.train.train import smart_tokenizer_and_embedding_resize
|
10 |
+
|
11 |
+
|
12 |
+
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
|
13 |
+
kwargs = {"device_map": device_map, **kwargs}
|
14 |
+
|
15 |
+
if device != "cuda":
|
16 |
+
kwargs['device_map'] = {"": device}
|
17 |
+
|
18 |
+
if load_8bit:
|
19 |
+
kwargs['load_in_8bit'] = True
|
20 |
+
elif load_4bit:
|
21 |
+
kwargs['load_in_4bit'] = True
|
22 |
+
kwargs['quantization_config'] = BitsAndBytesConfig(
|
23 |
+
load_in_4bit=True,
|
24 |
+
bnb_4bit_compute_dtype=torch.float16,
|
25 |
+
bnb_4bit_use_double_quant=True,
|
26 |
+
bnb_4bit_quant_type='nf4'
|
27 |
+
)
|
28 |
+
else:
|
29 |
+
kwargs['torch_dtype'] = torch.float16
|
30 |
+
|
31 |
+
if use_flash_attn:
|
32 |
+
kwargs['attn_implementation'] = 'flash_attention_2'
|
33 |
+
|
34 |
+
|
35 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
36 |
+
|
37 |
+
model = blip3oQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.float16).to('cuda:0')
|
38 |
+
|
39 |
+
image_processor = None
|
40 |
+
|
41 |
+
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
42 |
+
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
|
43 |
+
if mm_use_im_patch_token:
|
44 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
45 |
+
if mm_use_im_start_end:
|
46 |
+
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
47 |
+
model.resize_token_embeddings(len(tokenizer))
|
48 |
+
|
49 |
+
if hasattr(model.config, "max_sequence_length"):
|
50 |
+
context_len = model.config.max_sequence_length
|
51 |
+
else:
|
52 |
+
context_len = 2048
|
53 |
+
|
54 |
+
return tokenizer, model, context_len
|
blip3o/model/consolidate.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
5 |
+
from blip3o.model import *
|
6 |
+
from blip3o.model.utils import auto_upgrade
|
7 |
+
|
8 |
+
|
9 |
+
def consolidate_ckpt(src_path, dst_path):
|
10 |
+
print("Loading model")
|
11 |
+
auto_upgrade(src_path)
|
12 |
+
src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
13 |
+
src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
|
14 |
+
src_model.save_pretrained(dst_path)
|
15 |
+
src_tokenizer.save_pretrained(dst_path)
|
16 |
+
|
17 |
+
|
18 |
+
if __name__ == "__main__":
|
19 |
+
parser = argparse.ArgumentParser()
|
20 |
+
parser.add_argument("--src", type=str, required=True)
|
21 |
+
parser.add_argument("--dst", type=str, required=True)
|
22 |
+
|
23 |
+
args = parser.parse_args()
|
24 |
+
|
25 |
+
consolidate_ckpt(args.src, args.dst)
|
blip3o/model/language_model/blip3o_llama.py
ADDED
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Tuple, Union
|
2 |
+
from PIL import Image
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from transformers import AutoConfig, AutoModelForCausalLM, \
|
9 |
+
LlamaConfig, LlamaModel, LlamaForCausalLM, AutoTokenizer
|
10 |
+
|
11 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
12 |
+
from transformers.generation.utils import GenerateOutput
|
13 |
+
|
14 |
+
from ..blip3o_arch import blip3oMetaModel, blip3oMetaForCausalLM
|
15 |
+
from blip3o.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_IDX, DEFAULT_IM_START_TOKEN_IDX, DEFAULT_IM_END_TOKEN_IDX
|
16 |
+
import pdb
|
17 |
+
from diffusers.utils.torch_utils import randn_tensor
|
18 |
+
from diffusers.pipelines.pipeline_utils import numpy_to_pil
|
19 |
+
import numpy as np
|
20 |
+
from diffusers.models import AutoencoderKL
|
21 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
22 |
+
|
23 |
+
|
24 |
+
class blip3oConfig(LlamaConfig):
|
25 |
+
model_type = "blip3o_llama"
|
26 |
+
|
27 |
+
|
28 |
+
class blip3oLlamaModel(blip3oMetaModel, LlamaModel):
|
29 |
+
config_class = blip3oConfig
|
30 |
+
|
31 |
+
def __init__(self, config: LlamaConfig):
|
32 |
+
super(blip3oLlamaModel, self).__init__(config)
|
33 |
+
|
34 |
+
|
35 |
+
class blip3oLlamaForCausalLM(LlamaForCausalLM, blip3oMetaForCausalLM):
|
36 |
+
config_class = blip3oConfig
|
37 |
+
|
38 |
+
def __init__(self, config):
|
39 |
+
super(LlamaForCausalLM, self).__init__(config)
|
40 |
+
self.model = blip3oLlamaModel(config)
|
41 |
+
self.pretraining_tp = config.pretraining_tp
|
42 |
+
self.vocab_size = config.vocab_size
|
43 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
44 |
+
self.dist = None
|
45 |
+
|
46 |
+
# Initialize weights and apply final processing
|
47 |
+
self.post_init()
|
48 |
+
|
49 |
+
def get_model(self):
|
50 |
+
return self.model
|
51 |
+
|
52 |
+
def forward(
|
53 |
+
self,
|
54 |
+
input_ids: torch.LongTensor = None,
|
55 |
+
attention_mask: Optional[torch.Tensor] = None,
|
56 |
+
position_ids: Optional[torch.LongTensor] = None,
|
57 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
58 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
59 |
+
labels: Optional[torch.LongTensor] = None,
|
60 |
+
ids: Optional[list] = None,
|
61 |
+
i_s_pos: Optional[list] = None,
|
62 |
+
image_type: Optional[torch.Tensor] = None,
|
63 |
+
use_cache: Optional[bool] = None,
|
64 |
+
output_attentions: Optional[bool] = None,
|
65 |
+
output_hidden_states: Optional[bool] = None,
|
66 |
+
gen_image: Optional[torch.FloatTensor] = None,
|
67 |
+
und_image: Optional[torch.FloatTensor] = None,
|
68 |
+
image_sizes: Optional[List[List[int]]] = None,
|
69 |
+
return_dict: Optional[bool] = None,
|
70 |
+
cache_position: Optional[torch.LongTensor] = None
|
71 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
72 |
+
|
73 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
74 |
+
output_hidden_states = (
|
75 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
76 |
+
)
|
77 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
78 |
+
|
79 |
+
if inputs_embeds is None:
|
80 |
+
(
|
81 |
+
input_ids,
|
82 |
+
position_ids,
|
83 |
+
attention_mask,
|
84 |
+
past_key_values,
|
85 |
+
inputs_embeds,
|
86 |
+
labels,
|
87 |
+
latents
|
88 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
89 |
+
input_ids,
|
90 |
+
position_ids,
|
91 |
+
attention_mask,
|
92 |
+
past_key_values,
|
93 |
+
labels,
|
94 |
+
gen_image,
|
95 |
+
und_image,
|
96 |
+
i_s_pos,
|
97 |
+
image_sizes
|
98 |
+
)
|
99 |
+
|
100 |
+
outputs = self.model(
|
101 |
+
input_ids=input_ids,
|
102 |
+
attention_mask=attention_mask,
|
103 |
+
position_ids=position_ids,
|
104 |
+
past_key_values=past_key_values,
|
105 |
+
inputs_embeds=inputs_embeds,
|
106 |
+
use_cache=use_cache,
|
107 |
+
output_attentions=output_attentions,
|
108 |
+
output_hidden_states=output_hidden_states,
|
109 |
+
return_dict=return_dict,
|
110 |
+
)
|
111 |
+
|
112 |
+
hidden_states = outputs[0]
|
113 |
+
logits = self.lm_head(hidden_states)
|
114 |
+
logits = logits.float()
|
115 |
+
|
116 |
+
total_loss = None
|
117 |
+
if labels is not None:
|
118 |
+
# Shift so that tokens < n predict n
|
119 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
120 |
+
shift_labels = labels[..., 1:].contiguous()
|
121 |
+
# Flatten the tokens
|
122 |
+
loss_fct = torch.nn.CrossEntropyLoss()
|
123 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
124 |
+
shift_labels = shift_labels.view(-1)
|
125 |
+
# Enable model parallelism
|
126 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
127 |
+
loss = loss_fct(shift_logits, shift_labels)
|
128 |
+
|
129 |
+
|
130 |
+
# compute image loss
|
131 |
+
# target_img_embeds = torch.clone(inputs_embeds.detach())[:,1:,:] # get target image emb
|
132 |
+
img_loss_funct = torch.nn.MSELoss()
|
133 |
+
# img_hidden_states = self.get_model().down_projector(hidden_states[:,-self.get_n_query():,:])
|
134 |
+
img_hidden_states = []
|
135 |
+
|
136 |
+
for b in range(hidden_states.shape[0]):
|
137 |
+
img_hidden_states.append(hidden_states[b,i_s_pos[b]:i_s_pos[b]+64,:])
|
138 |
+
img_hidden_states = torch.stack(img_hidden_states,dim=0)
|
139 |
+
img_hidden_states = self.get_model().down_projector(img_hidden_states)
|
140 |
+
# img_loss = 0.0
|
141 |
+
if latents is None:
|
142 |
+
img_loss = img_loss_funct(img_hidden_states, torch.clone(img_hidden_states.detach()))
|
143 |
+
else:
|
144 |
+
bsz = latents.shape[0]
|
145 |
+
# device = latents.device
|
146 |
+
dtype = latents.dtype
|
147 |
+
noise = torch.randn_like(latents, device=latents.device)
|
148 |
+
u = torch.rand(size=(bsz,), device="cpu")
|
149 |
+
indices = (u * self.get_model().noise_scheduler.config.num_train_timesteps).long()
|
150 |
+
timesteps = self.get_model().noise_scheduler.timesteps[indices].to(device=latents.device)
|
151 |
+
sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=dtype)
|
152 |
+
noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
|
153 |
+
noise_pred = self.get_model().dit(
|
154 |
+
x=noisy_latents,
|
155 |
+
timestep=timesteps,
|
156 |
+
z_latents=self.mask_drop(img_hidden_states),
|
157 |
+
)
|
158 |
+
target = noise - latents
|
159 |
+
img_loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
160 |
+
print(f"img loss {img_loss}, text loss {loss}")
|
161 |
+
total_loss = img_loss
|
162 |
+
|
163 |
+
return CausalLMOutputWithPast(
|
164 |
+
loss=total_loss,
|
165 |
+
logits=logits,
|
166 |
+
past_key_values=outputs.past_key_values,
|
167 |
+
hidden_states=outputs.hidden_states,
|
168 |
+
attentions=outputs.attentions,
|
169 |
+
)
|
170 |
+
|
171 |
+
|
172 |
+
@torch.no_grad()
|
173 |
+
def generate(
|
174 |
+
self,
|
175 |
+
inputs: Optional[torch.Tensor] = None,
|
176 |
+
images: Optional[torch.Tensor] = None,
|
177 |
+
image_sizes: Optional[torch.Tensor] = None,
|
178 |
+
**kwargs,
|
179 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
180 |
+
position_ids = kwargs.pop("position_ids", None)
|
181 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
182 |
+
if "inputs_embeds" in kwargs:
|
183 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
184 |
+
|
185 |
+
if images is not None:
|
186 |
+
(
|
187 |
+
inputs,
|
188 |
+
position_ids,
|
189 |
+
attention_mask,
|
190 |
+
_,
|
191 |
+
inputs_embeds,
|
192 |
+
img_indicator,
|
193 |
+
_
|
194 |
+
) = self.prepare_inputs_labels_for_understanding(
|
195 |
+
inputs,
|
196 |
+
position_ids,
|
197 |
+
attention_mask,
|
198 |
+
None,
|
199 |
+
None,
|
200 |
+
images,
|
201 |
+
image_sizes=image_sizes
|
202 |
+
)
|
203 |
+
else:
|
204 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
205 |
+
|
206 |
+
return super().generate(
|
207 |
+
position_ids=position_ids,
|
208 |
+
attention_mask=attention_mask,
|
209 |
+
inputs_embeds=inputs_embeds,
|
210 |
+
**kwargs
|
211 |
+
)
|
212 |
+
|
213 |
+
@torch.no_grad()
|
214 |
+
def generate_image(
|
215 |
+
self,
|
216 |
+
text: List[str],
|
217 |
+
tokenizer: AutoTokenizer,
|
218 |
+
image: Optional[torch.Tensor] = None,
|
219 |
+
max_var: Optional[float] = None,
|
220 |
+
# placeholder: str = DEFAULT_IMG_PLACEHOLDER,
|
221 |
+
):
|
222 |
+
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler")
|
223 |
+
|
224 |
+
vision_tower = self.get_vision_tower()
|
225 |
+
mm_projector = self.get_mm_projector()
|
226 |
+
N_QUERY = self.get_n_query()
|
227 |
+
|
228 |
+
if image is not None:
|
229 |
+
# image: [Batch, 3, 448, 448]
|
230 |
+
prompt_image_embeds = vision_tower(batch_images)
|
231 |
+
num_img, _, c = prompt_image_embeds.shape # [batch, 576, 1024]
|
232 |
+
all_image_embeds = torch.clone(prompt_image_embeds).detach()
|
233 |
+
prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c)
|
234 |
+
prompt_image_embeds = mm_projector(prompt_image_embeds)
|
235 |
+
|
236 |
+
inputs = tokenizer(text, padding="longest", return_tensors="pt")
|
237 |
+
device = self.get_model().device
|
238 |
+
attention_mask = inputs.attention_mask.to(device)
|
239 |
+
input_ids = inputs.input_ids.to(device) # B x N
|
240 |
+
input_ids = torch.cat([input_ids, torch.tensor([[198]]).to(device)], dim=1)
|
241 |
+
|
242 |
+
# breakpoint()
|
243 |
+
text_embeds = self.get_model().embed_tokens(input_ids)
|
244 |
+
latent_queries = self.get_model().latent_queries.repeat(text_embeds.shape[0], 1, 1)
|
245 |
+
text_embeds = torch.cat([text_embeds, latent_queries], dim=1)
|
246 |
+
attention_mask = torch.cat([attention_mask, torch.ones_like(latent_queries[:, :, 0])], dim=1)
|
247 |
+
|
248 |
+
outputs = self.model(
|
249 |
+
inputs_embeds=text_embeds,
|
250 |
+
# img_indicator=img_indicator,
|
251 |
+
# concept_indicator=concept_indicator if self.use_concept_token else None,
|
252 |
+
attention_mask=attention_mask,
|
253 |
+
output_hidden_states=True,
|
254 |
+
return_dict=True,
|
255 |
+
)
|
256 |
+
hidden_states = outputs.hidden_states[-1][:,-N_QUERY:,:]
|
257 |
+
img_hidden_states = self.get_model().down_projector(hidden_states)
|
258 |
+
output_img = self.sample_images(img_hidden_states, scheduler)
|
259 |
+
output_img = output_img.view(1, 1792, -1).permute(0,2,1).contiguous()
|
260 |
+
|
261 |
+
return output_img
|
262 |
+
|
263 |
+
def sample_images(
|
264 |
+
self,
|
265 |
+
img_hidden_states,
|
266 |
+
scheduler,
|
267 |
+
guidance_scale: float = 3.0,
|
268 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
269 |
+
num_inference_steps: int = 30,
|
270 |
+
num_images_per_prompt: int = 1,
|
271 |
+
return_tensor=False,
|
272 |
+
**kwargs,
|
273 |
+
):
|
274 |
+
|
275 |
+
device = img_hidden_states.device
|
276 |
+
dtype = img_hidden_states.dtype
|
277 |
+
|
278 |
+
img_hidden_states_null = torch.zeros_like(img_hidden_states, device=device, dtype=dtype)
|
279 |
+
img_hidden_states_input = torch.cat([img_hidden_states_null, img_hidden_states], 0)
|
280 |
+
|
281 |
+
batch_size = img_hidden_states.shape[0]
|
282 |
+
latent_size = self.get_model().dit.config.input_size
|
283 |
+
latent_channels = self.get_model().dit.config.in_channels
|
284 |
+
|
285 |
+
latents = randn_tensor(
|
286 |
+
shape=(batch_size * num_images_per_prompt, latent_channels, latent_size, latent_size),
|
287 |
+
generator=generator,
|
288 |
+
device=device,
|
289 |
+
dtype=dtype,
|
290 |
+
)
|
291 |
+
|
292 |
+
# set step values
|
293 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
294 |
+
scheduler.set_timesteps(num_inference_steps, sigmas=sigmas)
|
295 |
+
|
296 |
+
# Repeat z_latents and conditions for each image per prompt
|
297 |
+
img_hidden_states_input = img_hidden_states_input.repeat_interleave(num_images_per_prompt, dim=0)
|
298 |
+
|
299 |
+
for t in scheduler.timesteps:
|
300 |
+
latent_model_input = latents.repeat(2, 1, 1, 1)
|
301 |
+
if hasattr(scheduler, "scale_model_input"):
|
302 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
303 |
+
|
304 |
+
# predict noise model_output
|
305 |
+
noise_pred = self.get_model().dit(
|
306 |
+
x=latent_model_input,
|
307 |
+
timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latent_model_input.device, torch.long),
|
308 |
+
z_latents=img_hidden_states_input,
|
309 |
+
)
|
310 |
+
|
311 |
+
# perform guidance
|
312 |
+
noise_pred_uncond, noise_pred = noise_pred.chunk(2)
|
313 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
|
314 |
+
|
315 |
+
# compute previous image: x_t -> x_t-1
|
316 |
+
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
317 |
+
|
318 |
+
# samples = self.decode_latents(latents, return_tensor=return_tensor)
|
319 |
+
return latents
|
320 |
+
|
321 |
+
def decode_latents(self, latents, normalize=True, return_tensor=False):
|
322 |
+
if isinstance(self.get_model().vae, AutoencoderKL):
|
323 |
+
latents = latents / self.get_model().vae.config.scaling_factor
|
324 |
+
if self.get_model().vae.config.shift_factor is not None:
|
325 |
+
latents = latents + self.get_model().vae.config.shift_factor
|
326 |
+
latents = latents.to(dtype=torch.float32)
|
327 |
+
samples = self.get_model().vae.decode(latents).sample
|
328 |
+
else:
|
329 |
+
samples = self.get_model().vae.decode(latents)
|
330 |
+
if normalize:
|
331 |
+
samples = (samples / 2 + 0.5).clamp(0, 1)
|
332 |
+
else:
|
333 |
+
samples = samples.clamp(-1, 1)
|
334 |
+
if return_tensor:
|
335 |
+
return samples
|
336 |
+
samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
|
337 |
+
samples = numpy_to_pil(samples)
|
338 |
+
return samples
|
339 |
+
|
340 |
+
def prepare_and_encode_inputs(
|
341 |
+
self,
|
342 |
+
inputs: List[str | Image.Image],
|
343 |
+
tokenizer: AutoTokenizer,
|
344 |
+
do_classifier_free_guidance: bool = False,
|
345 |
+
):
|
346 |
+
# pdb.set_trace()
|
347 |
+
device = self.get_model().device
|
348 |
+
dtype = self.get_model().dtype
|
349 |
+
|
350 |
+
has_image, has_text = False, False
|
351 |
+
text_prompt, image_prompt = "", []
|
352 |
+
img_processor = self.get_vision_tower().image_processor
|
353 |
+
negative_prompt = {}
|
354 |
+
|
355 |
+
for x in inputs:
|
356 |
+
if isinstance(x, str):
|
357 |
+
has_text = True
|
358 |
+
text_prompt += x
|
359 |
+
else:
|
360 |
+
has_image = True
|
361 |
+
text_prompt += DEFAULT_IMAGE_TOKEN
|
362 |
+
image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values'])
|
363 |
+
# pdb.set_trace()
|
364 |
+
if len(image_prompt) == 0:
|
365 |
+
image_prompt = None
|
366 |
+
else:
|
367 |
+
image_prompt = torch.cat(image_prompt)
|
368 |
+
image_prompt = image_prompt.type(dtype).to(device)
|
369 |
+
|
370 |
+
if has_image and not has_text:
|
371 |
+
prompt = self.encode_images(image_prompt)
|
372 |
+
# pdb.set_trace()
|
373 |
+
if do_classifier_free_guidance:
|
374 |
+
key = "[NULL_IMAGE]"
|
375 |
+
if key not in negative_prompt:
|
376 |
+
negative_image = torch.zeros_like(image_prompt)
|
377 |
+
negative_prompt[key] = self.encode_images(negative_image)
|
378 |
+
prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
|
379 |
+
else:
|
380 |
+
prompt = self.generate_image(text=[text_prompt], image=image_prompt, tokenizer=tokenizer)
|
381 |
+
if do_classifier_free_guidance:
|
382 |
+
key = ""
|
383 |
+
if key not in negative_prompt:
|
384 |
+
negative_prompt[key] = self.generate_image(text=[""], tokenizer=tokenizer)
|
385 |
+
prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
|
386 |
+
|
387 |
+
gen_pooling = self.get_gen_pooling()
|
388 |
+
n_query = self.get_n_query()
|
389 |
+
num_img, _, c = prompt.shape
|
390 |
+
if 'pool2d' in gen_pooling and has_text and not 'early' in gen_pooling:
|
391 |
+
stride = int(gen_pooling.split('_')[1])
|
392 |
+
sqrt_n = int(n_query**0.5)
|
393 |
+
prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n)
|
394 |
+
prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride)
|
395 |
+
prompt = prompt.reshape(num_img, c, -1).permute(0,2,1)
|
396 |
+
return prompt
|
397 |
+
|
398 |
+
|
399 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
|
400 |
+
inputs_embeds=None, **kwargs):
|
401 |
+
images = kwargs.pop("images", None)
|
402 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
403 |
+
inputs = super().prepare_inputs_for_generation(
|
404 |
+
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
|
405 |
+
)
|
406 |
+
if images is not None:
|
407 |
+
inputs['images'] = images
|
408 |
+
if image_sizes is not None:
|
409 |
+
inputs['image_sizes'] = image_sizes
|
410 |
+
return inputs
|
411 |
+
|
412 |
+
AutoConfig.register("blip3o_llama", blip3oConfig)
|
413 |
+
AutoModelForCausalLM.register(blip3oConfig, blip3oLlamaForCausalLM)
|
blip3o/model/language_model/blip3o_qwen.py
ADDED
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Tuple, Union, Dict
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from PIL import Image
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
import transformers
|
9 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
10 |
+
|
11 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
12 |
+
from transformers.generation.utils import GenerateOutput
|
13 |
+
|
14 |
+
from blip3o.model.blip3o_arch import blip3oMetaModel, blip3oMetaForCausalLM
|
15 |
+
|
16 |
+
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel, Qwen2_5_VLForConditionalGeneration
|
17 |
+
|
18 |
+
from blip3o.constants import UND_IMAGE_TOKEN_IDX
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
from diffusers.utils.torch_utils import randn_tensor
|
23 |
+
from diffusers.pipelines.pipeline_utils import numpy_to_pil
|
24 |
+
import numpy as np
|
25 |
+
from diffusers.models import AutoencoderKL
|
26 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
27 |
+
|
28 |
+
|
29 |
+
class blip3oQwenConfig(Qwen2_5_VLConfig):
|
30 |
+
model_type = "blip3o_qwen"
|
31 |
+
|
32 |
+
|
33 |
+
class blip3oQwenModel(blip3oMetaModel, Qwen2_5_VLModel):
|
34 |
+
config_class = blip3oQwenConfig
|
35 |
+
|
36 |
+
def __init__(self, config: Qwen2_5_VLConfig):
|
37 |
+
super(blip3oQwenModel, self).__init__(config)
|
38 |
+
|
39 |
+
|
40 |
+
class blip3oQwenForCausalLM(Qwen2_5_VLForConditionalGeneration, blip3oMetaForCausalLM):
|
41 |
+
config_class = blip3oQwenConfig
|
42 |
+
|
43 |
+
def __init__(self, config):
|
44 |
+
Qwen2_5_VLForConditionalGeneration.__init__(self, config)
|
45 |
+
config.model_type = "blip3o_qwen"
|
46 |
+
|
47 |
+
self.model = blip3oQwenModel(config)
|
48 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
49 |
+
# Initialize weights and apply final processing
|
50 |
+
self.post_init()
|
51 |
+
|
52 |
+
def get_model(self):
|
53 |
+
return self.model
|
54 |
+
|
55 |
+
|
56 |
+
def forward(
|
57 |
+
self,
|
58 |
+
input_ids: torch.LongTensor = None,
|
59 |
+
attention_mask: Optional[torch.Tensor] = None,
|
60 |
+
position_ids: Optional[torch.LongTensor] = None,
|
61 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
62 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
63 |
+
labels: Optional[torch.LongTensor] = None,
|
64 |
+
ids: Optional[list] = None,
|
65 |
+
i_s_pos: Optional[list] = None,
|
66 |
+
use_cache: Optional[bool] = None,
|
67 |
+
output_attentions: Optional[bool] = None,
|
68 |
+
output_hidden_states: Optional[bool] = None,
|
69 |
+
gen_image: Optional[torch.FloatTensor] = None,
|
70 |
+
und_image: Optional[torch.FloatTensor] = None,
|
71 |
+
grid_thw: Optional[torch.FloatTensor] = None,
|
72 |
+
image_sizes: Optional[List[List[int]]] = None,
|
73 |
+
return_dict: Optional[bool] = None,
|
74 |
+
cache_position: Optional[torch.LongTensor] = None
|
75 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
76 |
+
|
77 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
78 |
+
output_hidden_states = (
|
79 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
80 |
+
)
|
81 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
82 |
+
|
83 |
+
if inputs_embeds is None:
|
84 |
+
(
|
85 |
+
input_ids,
|
86 |
+
position_ids,
|
87 |
+
attention_mask,
|
88 |
+
past_key_values,
|
89 |
+
inputs_embeds,
|
90 |
+
labels,
|
91 |
+
latents
|
92 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
93 |
+
input_ids,
|
94 |
+
position_ids,
|
95 |
+
attention_mask,
|
96 |
+
past_key_values,
|
97 |
+
labels,
|
98 |
+
gen_image,
|
99 |
+
und_image,
|
100 |
+
grid_thw,
|
101 |
+
i_s_pos,
|
102 |
+
image_sizes
|
103 |
+
)
|
104 |
+
|
105 |
+
outputs = self.model(
|
106 |
+
input_ids=input_ids,
|
107 |
+
attention_mask=attention_mask,
|
108 |
+
position_ids=position_ids,
|
109 |
+
past_key_values=past_key_values,
|
110 |
+
inputs_embeds=inputs_embeds,
|
111 |
+
use_cache=use_cache,
|
112 |
+
output_attentions=output_attentions,
|
113 |
+
output_hidden_states=output_hidden_states,
|
114 |
+
return_dict=return_dict,
|
115 |
+
)
|
116 |
+
|
117 |
+
hidden_states = outputs[0]
|
118 |
+
logits = self.lm_head(hidden_states)
|
119 |
+
logits = logits.float()
|
120 |
+
|
121 |
+
total_loss = None
|
122 |
+
if labels is not None:
|
123 |
+
# Shift so that tokens < n predict n
|
124 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
125 |
+
shift_labels = labels[..., 1:].contiguous()
|
126 |
+
# Flatten the tokens
|
127 |
+
loss_fct = torch.nn.CrossEntropyLoss()
|
128 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
129 |
+
shift_labels = shift_labels.view(-1)
|
130 |
+
# Enable model parallelism
|
131 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
132 |
+
loss = loss_fct(shift_logits, shift_labels)
|
133 |
+
|
134 |
+
|
135 |
+
# compute image loss
|
136 |
+
# target_img_embeds = torch.clone(inputs_embeds.detach())[:,1:,:] # get target image emb
|
137 |
+
img_loss_funct = torch.nn.MSELoss()
|
138 |
+
# img_hidden_states = self.get_model().down_projector(hidden_states[:,-self.get_n_query():,:])
|
139 |
+
img_hidden_states = []
|
140 |
+
|
141 |
+
for b in range(hidden_states.shape[0]):
|
142 |
+
img_hidden_states.append(hidden_states[b,i_s_pos[b]:i_s_pos[b]+64,:])
|
143 |
+
img_hidden_states = torch.stack(img_hidden_states,dim=0)
|
144 |
+
img_hidden_states = self.get_model().down_projector(img_hidden_states)
|
145 |
+
# img_loss = 0.0
|
146 |
+
if latents is None:
|
147 |
+
img_loss = img_loss_funct(img_hidden_states, torch.clone(img_hidden_states.detach()))
|
148 |
+
else:
|
149 |
+
bsz = latents.shape[0]
|
150 |
+
# device = latents.device
|
151 |
+
dtype = latents.dtype
|
152 |
+
noise = torch.randn_like(latents, device=latents.device)
|
153 |
+
u = torch.rand(size=(bsz,), device="cpu")
|
154 |
+
indices = (u * self.get_model().noise_scheduler.config.num_train_timesteps).long()
|
155 |
+
timesteps = self.get_model().noise_scheduler.timesteps[indices].to(device=latents.device)
|
156 |
+
sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=dtype)
|
157 |
+
noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
|
158 |
+
noise_pred = self.get_model().dit(
|
159 |
+
x=noisy_latents,
|
160 |
+
timestep=timesteps,
|
161 |
+
z_latents=self.mask_drop(img_hidden_states),
|
162 |
+
)
|
163 |
+
target = noise - latents
|
164 |
+
img_loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
165 |
+
print(f"img loss {img_loss}")
|
166 |
+
total_loss = img_loss
|
167 |
+
|
168 |
+
return CausalLMOutputWithPast(
|
169 |
+
loss=total_loss,
|
170 |
+
logits=logits,
|
171 |
+
past_key_values=outputs.past_key_values,
|
172 |
+
hidden_states=outputs.hidden_states,
|
173 |
+
attentions=outputs.attentions,
|
174 |
+
)
|
175 |
+
|
176 |
+
|
177 |
+
@torch.no_grad()
|
178 |
+
def generate(
|
179 |
+
self,
|
180 |
+
inputs: Optional[torch.Tensor] = None,
|
181 |
+
images: Optional[torch.Tensor] = None,
|
182 |
+
image_sizes: Optional[torch.Tensor] = None,
|
183 |
+
**kwargs,
|
184 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
185 |
+
position_ids = kwargs.pop("position_ids", None)
|
186 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
187 |
+
if "inputs_embeds" in kwargs:
|
188 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
189 |
+
|
190 |
+
if images is not None:
|
191 |
+
(
|
192 |
+
inputs,
|
193 |
+
position_ids,
|
194 |
+
attention_mask,
|
195 |
+
_,
|
196 |
+
inputs_embeds,
|
197 |
+
img_indicator,
|
198 |
+
_
|
199 |
+
) = self.prepare_inputs_labels_for_understanding(
|
200 |
+
inputs,
|
201 |
+
position_ids,
|
202 |
+
attention_mask,
|
203 |
+
None,
|
204 |
+
None,
|
205 |
+
images,
|
206 |
+
image_sizes=image_sizes
|
207 |
+
)
|
208 |
+
else:
|
209 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
210 |
+
|
211 |
+
return super().generate(
|
212 |
+
position_ids=position_ids,
|
213 |
+
attention_mask=attention_mask,
|
214 |
+
inputs_embeds=inputs_embeds,
|
215 |
+
**kwargs
|
216 |
+
)
|
217 |
+
|
218 |
+
@torch.no_grad()
|
219 |
+
def generate_image(
|
220 |
+
self,
|
221 |
+
text: List[str],
|
222 |
+
tokenizer: AutoTokenizer,
|
223 |
+
pixel_values: Optional[torch.Tensor] = None,
|
224 |
+
image_grid_thw: Optional[torch.Tensor] = None,
|
225 |
+
max_var: Optional[float] = None,
|
226 |
+
# placeholder: str = DEFAULT_IMG_PLACEHOLDER,
|
227 |
+
):
|
228 |
+
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler")
|
229 |
+
|
230 |
+
|
231 |
+
N_QUERY = self.get_n_query()
|
232 |
+
inputs = tokenizer(text, padding="longest", return_tensors="pt")
|
233 |
+
device = self.get_model().device
|
234 |
+
attention_mask = inputs.attention_mask.to(device)
|
235 |
+
input_ids = inputs.input_ids.to(device) # B x N
|
236 |
+
input_ids = torch.cat([input_ids, torch.tensor([[151665]]).to(device)], dim=1)
|
237 |
+
# breakpoint()
|
238 |
+
|
239 |
+
|
240 |
+
text_embeds = self.get_model().embed_tokens(input_ids)
|
241 |
+
latent_queries = self.get_model().latent_queries.repeat(text_embeds.shape[0], 1, 1)
|
242 |
+
|
243 |
+
|
244 |
+
if pixel_values is not None:
|
245 |
+
und_image_idx = (input_ids == UND_IMAGE_TOKEN_IDX)
|
246 |
+
pixel_values = pixel_values.type(self.visual.dtype)
|
247 |
+
und_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
248 |
+
text_embeds[und_image_idx] = und_image_embeds.to(text_embeds.device)[:und_image_idx.sum(), :]
|
249 |
+
|
250 |
+
|
251 |
+
text_embeds = torch.cat([text_embeds, latent_queries], dim=1)
|
252 |
+
attention_mask = torch.cat([attention_mask, torch.ones_like(latent_queries[:, :, 0])], dim=1)
|
253 |
+
|
254 |
+
|
255 |
+
outputs = self.model(
|
256 |
+
inputs_embeds=text_embeds,
|
257 |
+
attention_mask=attention_mask,
|
258 |
+
output_hidden_states=True,
|
259 |
+
return_dict=True,
|
260 |
+
)
|
261 |
+
hidden_states = outputs.hidden_states[-1][:,-N_QUERY:,:]
|
262 |
+
img_hidden_states = hidden_states
|
263 |
+
output_img = self.sample_images(img_hidden_states, scheduler)
|
264 |
+
output_img = output_img.view(1, 1792, -1).permute(0,2,1).contiguous()
|
265 |
+
|
266 |
+
return output_img
|
267 |
+
|
268 |
+
def sample_images(
|
269 |
+
self,
|
270 |
+
img_hidden_states,
|
271 |
+
scheduler,
|
272 |
+
guidance_scale: float = 3.0,
|
273 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
274 |
+
num_inference_steps: int = 30,
|
275 |
+
num_images_per_prompt: int = 1,
|
276 |
+
return_tensor=False,
|
277 |
+
**kwargs,
|
278 |
+
):
|
279 |
+
|
280 |
+
device = img_hidden_states.device
|
281 |
+
dtype = img_hidden_states.dtype
|
282 |
+
|
283 |
+
|
284 |
+
img_hidden_states_null = torch.zeros_like(img_hidden_states, device=device, dtype=dtype)
|
285 |
+
img_hidden_states_input = torch.cat([img_hidden_states_null, img_hidden_states], 0)
|
286 |
+
|
287 |
+
batch_size = img_hidden_states.shape[0]
|
288 |
+
latent_size = self.get_model().dit.config.input_size
|
289 |
+
latent_channels = self.get_model().dit.config.in_channels
|
290 |
+
|
291 |
+
latents = randn_tensor(
|
292 |
+
shape=(batch_size * num_images_per_prompt, latent_channels, latent_size, latent_size),
|
293 |
+
generator=generator,
|
294 |
+
device=device,
|
295 |
+
dtype=dtype,
|
296 |
+
)
|
297 |
+
|
298 |
+
# set step values
|
299 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
300 |
+
scheduler.set_timesteps(num_inference_steps, sigmas=sigmas)
|
301 |
+
|
302 |
+
# Repeat z_latents and conditions for each image per prompt
|
303 |
+
img_hidden_states_input = img_hidden_states_input.repeat_interleave(num_images_per_prompt, dim=0)
|
304 |
+
|
305 |
+
for t in scheduler.timesteps:
|
306 |
+
latent_model_input = latents.repeat(2, 1, 1, 1)
|
307 |
+
if hasattr(scheduler, "scale_model_input"):
|
308 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
309 |
+
|
310 |
+
# predict noise model_output
|
311 |
+
noise_pred = self.get_model().dit(
|
312 |
+
x=latent_model_input,
|
313 |
+
timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latent_model_input.device, torch.long),
|
314 |
+
z_latents=img_hidden_states_input,
|
315 |
+
)
|
316 |
+
|
317 |
+
# perform guidance
|
318 |
+
noise_pred_uncond, noise_pred = noise_pred.chunk(2)
|
319 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
|
320 |
+
|
321 |
+
# compute previous image: x_t -> x_t-1
|
322 |
+
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
323 |
+
|
324 |
+
# samples = self.decode_latents(latents, return_tensor=return_tensor)
|
325 |
+
# breakpoint()
|
326 |
+
return latents
|
327 |
+
|
328 |
+
def decode_latents(self, latents, normalize=True, return_tensor=False):
|
329 |
+
if isinstance(self.get_model().vae, AutoencoderKL):
|
330 |
+
latents = latents / self.get_model().vae.config.scaling_factor
|
331 |
+
if self.get_model().vae.config.shift_factor is not None:
|
332 |
+
latents = latents + self.get_model().vae.config.shift_factor
|
333 |
+
latents = latents.to(dtype=torch.float32)
|
334 |
+
samples = self.get_model().vae.decode(latents).sample
|
335 |
+
else:
|
336 |
+
samples = self.get_model().vae.decode(latents)
|
337 |
+
if normalize:
|
338 |
+
samples = (samples / 2 + 0.5).clamp(0, 1)
|
339 |
+
else:
|
340 |
+
samples = samples.clamp(-1, 1)
|
341 |
+
if return_tensor:
|
342 |
+
return samples
|
343 |
+
samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
|
344 |
+
samples = numpy_to_pil(samples)
|
345 |
+
return samples
|
346 |
+
|
347 |
+
def prepare_and_encode_inputs(
|
348 |
+
self,
|
349 |
+
inputs: List[str | Image.Image],
|
350 |
+
tokenizer: AutoTokenizer,
|
351 |
+
do_classifier_free_guidance: bool = False,
|
352 |
+
):
|
353 |
+
# pdb.set_trace()
|
354 |
+
device = self.get_model().device
|
355 |
+
dtype = self.get_model().dtype
|
356 |
+
|
357 |
+
has_image, has_text = False, False
|
358 |
+
text_prompt, image_prompt = "", []
|
359 |
+
img_processor = self.get_vision_tower().image_processor
|
360 |
+
negative_prompt = {}
|
361 |
+
|
362 |
+
for x in inputs:
|
363 |
+
if isinstance(x, str):
|
364 |
+
has_text = True
|
365 |
+
text_prompt += x
|
366 |
+
else:
|
367 |
+
has_image = True
|
368 |
+
text_prompt += DEFAULT_IMAGE_TOKEN
|
369 |
+
image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values'])
|
370 |
+
# pdb.set_trace()
|
371 |
+
if len(image_prompt) == 0:
|
372 |
+
image_prompt = None
|
373 |
+
else:
|
374 |
+
image_prompt = torch.cat(image_prompt)
|
375 |
+
image_prompt = image_prompt.type(dtype).to(device)
|
376 |
+
|
377 |
+
if has_image and not has_text:
|
378 |
+
prompt = self.encode_images(image_prompt)
|
379 |
+
# pdb.set_trace()
|
380 |
+
if do_classifier_free_guidance:
|
381 |
+
key = "[NULL_IMAGE]"
|
382 |
+
if key not in negative_prompt:
|
383 |
+
negative_image = torch.zeros_like(image_prompt)
|
384 |
+
negative_prompt[key] = self.encode_images(negative_image)
|
385 |
+
prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
|
386 |
+
else:
|
387 |
+
prompt = self.generate_image(text=[text_prompt], image=image_prompt, tokenizer=tokenizer)
|
388 |
+
if do_classifier_free_guidance:
|
389 |
+
key = ""
|
390 |
+
if key not in negative_prompt:
|
391 |
+
negative_prompt[key] = self.generate_image(text=[""], tokenizer=tokenizer)
|
392 |
+
prompt = torch.cat([prompt, negative_prompt[key]], dim=0)
|
393 |
+
|
394 |
+
gen_pooling = self.get_gen_pooling()
|
395 |
+
n_query = self.get_n_query()
|
396 |
+
num_img, _, c = prompt.shape
|
397 |
+
if 'pool2d' in gen_pooling and has_text and not 'early' in gen_pooling:
|
398 |
+
stride = int(gen_pooling.split('_')[1])
|
399 |
+
sqrt_n = int(n_query**0.5)
|
400 |
+
prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n)
|
401 |
+
prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride)
|
402 |
+
prompt = prompt.reshape(num_img, c, -1).permute(0,2,1)
|
403 |
+
return prompt
|
404 |
+
|
405 |
+
|
406 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
|
407 |
+
inputs_embeds=None, **kwargs):
|
408 |
+
images = kwargs.pop("images", None)
|
409 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
410 |
+
inputs = super().prepare_inputs_for_generation(
|
411 |
+
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
|
412 |
+
)
|
413 |
+
if images is not None:
|
414 |
+
inputs['images'] = images
|
415 |
+
if image_sizes is not None:
|
416 |
+
inputs['image_sizes'] = image_sizes
|
417 |
+
return inputs
|
418 |
+
|
419 |
+
AutoConfig.register("blip3o_qwen", blip3oQwenConfig)
|
420 |
+
AutoModelForCausalLM.register(blip3oQwenConfig, blip3oQwenForCausalLM)
|
blip3o/model/lumina_nextdit2d.py
ADDED
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import Any, Dict, Optional
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
20 |
+
from diffusers.models.attention import LuminaFeedForward
|
21 |
+
from diffusers.models.attention_processor import Attention, LuminaAttnProcessor2_0
|
22 |
+
from diffusers.models.embeddings import LuminaCombinedTimestepCaptionEmbedding, LuminaPatchEmbed, PixArtAlphaTextProjection
|
23 |
+
|
24 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
25 |
+
from diffusers.models.modeling_utils import ModelMixin
|
26 |
+
from diffusers.models.normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
|
27 |
+
from diffusers.utils import is_torch_version, logging
|
28 |
+
|
29 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
30 |
+
|
31 |
+
|
32 |
+
class LuminaNextDiTBlock(nn.Module):
|
33 |
+
"""
|
34 |
+
A LuminaNextDiTBlock for LuminaNextDiT2DModel.
|
35 |
+
|
36 |
+
Parameters:
|
37 |
+
dim (`int`): Embedding dimension of the input features.
|
38 |
+
num_attention_heads (`int`): Number of attention heads.
|
39 |
+
num_kv_heads (`int`):
|
40 |
+
Number of attention heads in key and value features (if using GQA), or set to None for the same as query.
|
41 |
+
multiple_of (`int`): The number of multiple of ffn layer.
|
42 |
+
ffn_dim_multiplier (`float`): The multipier factor of ffn layer dimension.
|
43 |
+
norm_eps (`float`): The eps for norm layer.
|
44 |
+
qk_norm (`bool`): normalization for query and key.
|
45 |
+
cross_attention_dim (`int`): Cross attention embedding dimension of the input text prompt hidden_states.
|
46 |
+
norm_elementwise_affine (`bool`, *optional*, defaults to True),
|
47 |
+
"""
|
48 |
+
|
49 |
+
def __init__(
|
50 |
+
self,
|
51 |
+
dim: int,
|
52 |
+
num_attention_heads: int,
|
53 |
+
num_kv_heads: int,
|
54 |
+
multiple_of: int,
|
55 |
+
ffn_dim_multiplier: float,
|
56 |
+
norm_eps: float,
|
57 |
+
qk_norm: bool,
|
58 |
+
cross_attention_dim: int,
|
59 |
+
norm_elementwise_affine: bool = True,
|
60 |
+
) -> None:
|
61 |
+
super().__init__()
|
62 |
+
self.head_dim = dim // num_attention_heads
|
63 |
+
|
64 |
+
self.gate = nn.Parameter(torch.zeros([num_attention_heads]))
|
65 |
+
|
66 |
+
# Self-attention
|
67 |
+
self.attn1 = Attention(
|
68 |
+
query_dim=dim,
|
69 |
+
cross_attention_dim=None,
|
70 |
+
dim_head=dim // num_attention_heads,
|
71 |
+
qk_norm="layer_norm_across_heads" if qk_norm else None,
|
72 |
+
heads=num_attention_heads,
|
73 |
+
kv_heads=num_kv_heads,
|
74 |
+
eps=1e-5,
|
75 |
+
bias=False,
|
76 |
+
out_bias=False,
|
77 |
+
processor=LuminaAttnProcessor2_0(),
|
78 |
+
)
|
79 |
+
self.attn1.to_out = nn.Identity()
|
80 |
+
|
81 |
+
# Cross-attention
|
82 |
+
self.attn2 = Attention(
|
83 |
+
query_dim=dim,
|
84 |
+
cross_attention_dim=cross_attention_dim,
|
85 |
+
dim_head=dim // num_attention_heads,
|
86 |
+
qk_norm="layer_norm_across_heads" if qk_norm else None,
|
87 |
+
heads=num_attention_heads,
|
88 |
+
kv_heads=num_kv_heads,
|
89 |
+
eps=1e-5,
|
90 |
+
bias=False,
|
91 |
+
out_bias=False,
|
92 |
+
processor=LuminaAttnProcessor2_0(),
|
93 |
+
)
|
94 |
+
|
95 |
+
self.feed_forward = LuminaFeedForward(
|
96 |
+
dim=dim,
|
97 |
+
inner_dim=4 * dim,
|
98 |
+
multiple_of=multiple_of,
|
99 |
+
ffn_dim_multiplier=ffn_dim_multiplier,
|
100 |
+
)
|
101 |
+
|
102 |
+
self.norm1 = LuminaRMSNormZero(
|
103 |
+
embedding_dim=dim,
|
104 |
+
norm_eps=norm_eps,
|
105 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
106 |
+
)
|
107 |
+
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
108 |
+
|
109 |
+
self.norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
110 |
+
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
111 |
+
|
112 |
+
self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
113 |
+
|
114 |
+
def forward(
|
115 |
+
self,
|
116 |
+
hidden_states: torch.Tensor,
|
117 |
+
attention_mask: torch.Tensor,
|
118 |
+
image_rotary_emb: torch.Tensor,
|
119 |
+
encoder_hidden_states: torch.Tensor,
|
120 |
+
encoder_mask: torch.Tensor,
|
121 |
+
temb: torch.Tensor,
|
122 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
123 |
+
):
|
124 |
+
"""
|
125 |
+
Perform a forward pass through the LuminaNextDiTBlock.
|
126 |
+
|
127 |
+
Parameters:
|
128 |
+
hidden_states (`torch.Tensor`): The input of hidden_states for LuminaNextDiTBlock.
|
129 |
+
attention_mask (`torch.Tensor): The input of hidden_states corresponse attention mask.
|
130 |
+
image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies.
|
131 |
+
encoder_hidden_states: (`torch.Tensor`): The hidden_states of text prompt are processed by Gemma encoder.
|
132 |
+
encoder_mask (`torch.Tensor`): The hidden_states of text prompt attention mask.
|
133 |
+
temb (`torch.Tensor`): Timestep embedding with text prompt embedding.
|
134 |
+
cross_attention_kwargs (`Dict[str, Any]`): kwargs for cross attention.
|
135 |
+
"""
|
136 |
+
residual = hidden_states
|
137 |
+
|
138 |
+
# Self-attention
|
139 |
+
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
140 |
+
self_attn_output = self.attn1(
|
141 |
+
hidden_states=norm_hidden_states,
|
142 |
+
encoder_hidden_states=norm_hidden_states,
|
143 |
+
attention_mask=attention_mask,
|
144 |
+
query_rotary_emb=image_rotary_emb,
|
145 |
+
key_rotary_emb=image_rotary_emb,
|
146 |
+
**cross_attention_kwargs,
|
147 |
+
)
|
148 |
+
|
149 |
+
# Cross-attention
|
150 |
+
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
|
151 |
+
cross_attn_output = self.attn2(
|
152 |
+
hidden_states=norm_hidden_states,
|
153 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
154 |
+
attention_mask=encoder_mask,
|
155 |
+
query_rotary_emb=image_rotary_emb,
|
156 |
+
key_rotary_emb=None,
|
157 |
+
**cross_attention_kwargs,
|
158 |
+
)
|
159 |
+
cross_attn_output = cross_attn_output * self.gate.tanh().view(1, 1, -1, 1)
|
160 |
+
mixed_attn_output = self_attn_output + cross_attn_output
|
161 |
+
mixed_attn_output = mixed_attn_output.flatten(-2)
|
162 |
+
# linear proj
|
163 |
+
hidden_states = self.attn2.to_out[0](mixed_attn_output)
|
164 |
+
|
165 |
+
hidden_states = residual + gate_msa.unsqueeze(1).tanh() * self.norm2(hidden_states)
|
166 |
+
|
167 |
+
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
|
168 |
+
|
169 |
+
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
|
170 |
+
|
171 |
+
return hidden_states
|
172 |
+
|
173 |
+
|
174 |
+
class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
|
175 |
+
"""
|
176 |
+
LuminaNextDiT: Diffusion model with a Transformer backbone.
|
177 |
+
|
178 |
+
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
|
179 |
+
|
180 |
+
Parameters:
|
181 |
+
sample_size (`int`): The width of the latent images. This is fixed during training since
|
182 |
+
it is used to learn a number of position embeddings.
|
183 |
+
patch_size (`int`, *optional*, (`int`, *optional*, defaults to 2):
|
184 |
+
The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
|
185 |
+
in_channels (`int`, *optional*, defaults to 4):
|
186 |
+
The number of input channels for the model. Typically, this matches the number of channels in the input
|
187 |
+
images.
|
188 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
189 |
+
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
|
190 |
+
hidden representations.
|
191 |
+
num_layers (`int`, *optional*, default to 32):
|
192 |
+
The number of layers in the model. This defines the depth of the neural network.
|
193 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
194 |
+
The number of attention heads in each attention layer. This parameter specifies how many separate attention
|
195 |
+
mechanisms are used.
|
196 |
+
num_kv_heads (`int`, *optional*, defaults to 8):
|
197 |
+
The number of key-value heads in the attention mechanism, if different from the number of attention heads.
|
198 |
+
If None, it defaults to num_attention_heads.
|
199 |
+
multiple_of (`int`, *optional*, defaults to 256):
|
200 |
+
A factor that the hidden size should be a multiple of. This can help optimize certain hardware
|
201 |
+
configurations.
|
202 |
+
ffn_dim_multiplier (`float`, *optional*):
|
203 |
+
A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
|
204 |
+
the model configuration.
|
205 |
+
norm_eps (`float`, *optional*, defaults to 1e-5):
|
206 |
+
A small value added to the denominator for numerical stability in normalization layers.
|
207 |
+
learn_sigma (`bool`, *optional*, defaults to True):
|
208 |
+
Whether the model should learn the sigma parameter, which might be related to uncertainty or variance in
|
209 |
+
predictions.
|
210 |
+
qk_norm (`bool`, *optional*, defaults to True):
|
211 |
+
Indicates if the queries and keys in the attention mechanism should be normalized.
|
212 |
+
cross_attention_dim (`int`, *optional*, defaults to 2048):
|
213 |
+
The dimensionality of the text embeddings. This parameter defines the size of the text representations used
|
214 |
+
in the model.
|
215 |
+
scaling_factor (`float`, *optional*, defaults to 1.0):
|
216 |
+
A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
|
217 |
+
overall scale of the model's operations.
|
218 |
+
"""
|
219 |
+
|
220 |
+
_supports_gradient_checkpointing = True
|
221 |
+
_no_split_modules = ["LuminaNextDiTBlock"]
|
222 |
+
|
223 |
+
@register_to_config
|
224 |
+
def __init__(
|
225 |
+
self,
|
226 |
+
sample_size: int = 128,
|
227 |
+
patch_size: Optional[int] = 2,
|
228 |
+
in_channels: Optional[int] = 4,
|
229 |
+
hidden_size: Optional[int] = 2304,
|
230 |
+
num_layers: Optional[int] = 32, # 32
|
231 |
+
num_attention_heads: Optional[int] = 32, # 32
|
232 |
+
num_kv_heads: Optional[int] = None,
|
233 |
+
multiple_of: Optional[int] = 256,
|
234 |
+
ffn_dim_multiplier: Optional[float] = None,
|
235 |
+
norm_eps: Optional[float] = 1e-5,
|
236 |
+
learn_sigma: Optional[bool] = True,
|
237 |
+
qk_norm: Optional[bool] = True,
|
238 |
+
cross_attention_dim: Optional[int] = 2048,
|
239 |
+
scaling_factor: Optional[float] = 1.0,
|
240 |
+
) -> None:
|
241 |
+
super().__init__()
|
242 |
+
self.sample_size = sample_size
|
243 |
+
self.patch_size = patch_size
|
244 |
+
self.in_channels = in_channels
|
245 |
+
self.out_channels = in_channels * 2 if learn_sigma else in_channels
|
246 |
+
self.hidden_size = hidden_size
|
247 |
+
self.num_attention_heads = num_attention_heads
|
248 |
+
self.head_dim = hidden_size // num_attention_heads
|
249 |
+
self.scaling_factor = scaling_factor
|
250 |
+
self.gradient_checkpointing = False
|
251 |
+
|
252 |
+
self.caption_projection = PixArtAlphaTextProjection(in_features=cross_attention_dim, hidden_size=hidden_size)
|
253 |
+
self.patch_embedder = LuminaPatchEmbed(patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, bias=True)
|
254 |
+
|
255 |
+
self.time_caption_embed = LuminaCombinedTimestepCaptionEmbedding(hidden_size=min(hidden_size, 1024), cross_attention_dim=hidden_size)
|
256 |
+
|
257 |
+
self.layers = nn.ModuleList(
|
258 |
+
[
|
259 |
+
LuminaNextDiTBlock(
|
260 |
+
hidden_size,
|
261 |
+
num_attention_heads,
|
262 |
+
num_kv_heads,
|
263 |
+
multiple_of,
|
264 |
+
ffn_dim_multiplier,
|
265 |
+
norm_eps,
|
266 |
+
qk_norm,
|
267 |
+
hidden_size,
|
268 |
+
)
|
269 |
+
for _ in range(num_layers)
|
270 |
+
]
|
271 |
+
)
|
272 |
+
self.norm_out = LuminaLayerNormContinuous(
|
273 |
+
embedding_dim=hidden_size,
|
274 |
+
conditioning_embedding_dim=min(hidden_size, 1024),
|
275 |
+
elementwise_affine=False,
|
276 |
+
eps=1e-6,
|
277 |
+
bias=True,
|
278 |
+
out_dim=patch_size * patch_size * self.out_channels,
|
279 |
+
)
|
280 |
+
# self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels)
|
281 |
+
|
282 |
+
assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
|
283 |
+
|
284 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
285 |
+
if hasattr(module, "gradient_checkpointing"):
|
286 |
+
module.gradient_checkpointing = value
|
287 |
+
|
288 |
+
def forward(
|
289 |
+
self,
|
290 |
+
hidden_states: torch.Tensor,
|
291 |
+
timestep: torch.Tensor,
|
292 |
+
encoder_hidden_states: torch.Tensor,
|
293 |
+
encoder_mask: torch.Tensor,
|
294 |
+
image_rotary_emb: torch.Tensor,
|
295 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
296 |
+
return_dict=True,
|
297 |
+
) -> torch.Tensor:
|
298 |
+
"""
|
299 |
+
Forward pass of LuminaNextDiT.
|
300 |
+
|
301 |
+
Parameters:
|
302 |
+
hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
|
303 |
+
timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,).
|
304 |
+
encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D).
|
305 |
+
encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L).
|
306 |
+
"""
|
307 |
+
hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)
|
308 |
+
image_rotary_emb = image_rotary_emb.to(hidden_states.device)
|
309 |
+
# breakpoint()
|
310 |
+
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
311 |
+
temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask)
|
312 |
+
|
313 |
+
encoder_mask = encoder_mask.bool()
|
314 |
+
|
315 |
+
for layer in self.layers:
|
316 |
+
if self.training and self.gradient_checkpointing:
|
317 |
+
|
318 |
+
def create_custom_forward(module, return_dict=None):
|
319 |
+
def custom_forward(*inputs):
|
320 |
+
if return_dict is not None:
|
321 |
+
return module(*inputs, return_dict=return_dict)
|
322 |
+
else:
|
323 |
+
return module(*inputs)
|
324 |
+
|
325 |
+
return custom_forward
|
326 |
+
|
327 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
328 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
329 |
+
create_custom_forward(layer),
|
330 |
+
hidden_states,
|
331 |
+
mask,
|
332 |
+
image_rotary_emb,
|
333 |
+
encoder_hidden_states,
|
334 |
+
encoder_mask,
|
335 |
+
temb,
|
336 |
+
cross_attention_kwargs,
|
337 |
+
**ckpt_kwargs,
|
338 |
+
)
|
339 |
+
else:
|
340 |
+
hidden_states = layer(
|
341 |
+
hidden_states,
|
342 |
+
mask,
|
343 |
+
image_rotary_emb,
|
344 |
+
encoder_hidden_states,
|
345 |
+
encoder_mask,
|
346 |
+
temb=temb,
|
347 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
348 |
+
)
|
349 |
+
|
350 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
351 |
+
|
352 |
+
# unpatchify
|
353 |
+
height_tokens = width_tokens = self.patch_size
|
354 |
+
height, width = img_size[0]
|
355 |
+
batch_size = hidden_states.size(0)
|
356 |
+
sequence_length = (height // height_tokens) * (width // width_tokens)
|
357 |
+
hidden_states = hidden_states[:, :sequence_length].view(
|
358 |
+
batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels
|
359 |
+
)
|
360 |
+
output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
|
361 |
+
|
362 |
+
if not return_dict:
|
363 |
+
return (output,)
|
364 |
+
|
365 |
+
return Transformer2DModelOutput(sample=output)
|
blip3o/model/make_delta.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from tqdm import tqdm
|
5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
6 |
+
from blip3o.model.utils import auto_upgrade
|
7 |
+
|
8 |
+
|
9 |
+
def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
|
10 |
+
print("Loading base model")
|
11 |
+
base = AutoModelForCausalLM.from_pretrained(
|
12 |
+
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
13 |
+
|
14 |
+
print("Loading target model")
|
15 |
+
auto_upgrade(target_model_path)
|
16 |
+
target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
17 |
+
|
18 |
+
print("Calculating delta")
|
19 |
+
for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
|
20 |
+
if name not in base.state_dict():
|
21 |
+
assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
|
22 |
+
continue
|
23 |
+
if param.data.shape == base.state_dict()[name].shape:
|
24 |
+
param.data -= base.state_dict()[name]
|
25 |
+
else:
|
26 |
+
assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
|
27 |
+
bparam = base.state_dict()[name]
|
28 |
+
param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
|
29 |
+
|
30 |
+
print("Saving delta")
|
31 |
+
if hub_repo_id:
|
32 |
+
kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
|
33 |
+
else:
|
34 |
+
kwargs = {}
|
35 |
+
target.save_pretrained(delta_path, **kwargs)
|
36 |
+
target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
|
37 |
+
target_tokenizer.save_pretrained(delta_path, **kwargs)
|
38 |
+
|
39 |
+
|
40 |
+
if __name__ == "__main__":
|
41 |
+
parser = argparse.ArgumentParser()
|
42 |
+
parser.add_argument("--base-model-path", type=str, required=True)
|
43 |
+
parser.add_argument("--target-model-path", type=str, required=True)
|
44 |
+
parser.add_argument("--delta-path", type=str, required=True)
|
45 |
+
parser.add_argument("--hub-repo-id", type=str, default=None)
|
46 |
+
args = parser.parse_args()
|
47 |
+
|
48 |
+
make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
|
blip3o/model/multimodal_encoder/builder.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from .clip_encoder import CLIPVisionTower
|
3 |
+
from .imagebind import ImageBindWrapper
|
4 |
+
from .open_clip_encoder import OpenCLIPVisionTower
|
5 |
+
from .siglip_encoder import SigLipVisionTower
|
6 |
+
from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
|
7 |
+
|
8 |
+
from .eva_clip.eva_clip_encoder import EvaClipVisionTower
|
9 |
+
from .dev_eva_clip.eva_vit import EvaViTWrapper
|
10 |
+
|
11 |
+
from blip3o.model.nextdit_crossattn import NextDiTCrossAttnConfig, NextDiTCrossAttn
|
12 |
+
from diffusers.models import AutoencoderKL
|
13 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
14 |
+
|
15 |
+
|
16 |
+
def build_vision_tower(vision_tower_cfg, **kwargs):
|
17 |
+
vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
|
18 |
+
is_absolute_path_exists = os.path.exists(vision_tower)
|
19 |
+
use_s2 = getattr(vision_tower_cfg, 's2', False)
|
20 |
+
if "siglip" in vision_tower:
|
21 |
+
return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
|
22 |
+
if "eva" in vision_tower:
|
23 |
+
return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
|
24 |
+
if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
|
25 |
+
if use_s2:
|
26 |
+
return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
|
27 |
+
else:
|
28 |
+
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
|
29 |
+
|
30 |
+
raise ValueError(f'Unknown vision tower: {vision_tower}')
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
def build_gen_vision_tower(vision_tower_cfg, **kwargs):
|
36 |
+
vision_tower = getattr(vision_tower_cfg, 'gen_vision_tower')
|
37 |
+
is_absolute_path_exists = os.path.exists(vision_tower)
|
38 |
+
use_s2 = getattr(vision_tower_cfg, 's2', False)
|
39 |
+
if "siglip" in vision_tower:
|
40 |
+
return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
|
41 |
+
if "eva" in vision_tower:
|
42 |
+
return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
|
43 |
+
if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
|
44 |
+
if use_s2:
|
45 |
+
return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
|
46 |
+
else:
|
47 |
+
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
|
48 |
+
|
49 |
+
raise ValueError(f'Unknown vision tower: {vision_tower}')
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
def build_dit(vision_tower_cfg, **kwargs):
|
54 |
+
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
|
55 |
+
# vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
|
56 |
+
dit = NextDiTCrossAttn(NextDiTCrossAttnConfig())
|
57 |
+
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler")
|
58 |
+
# scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler")
|
59 |
+
vae.eval()
|
60 |
+
vae.requires_grad_(False)
|
61 |
+
return dit, vae, noise_scheduler
|
62 |
+
|
63 |
+
|
blip3o/model/multimodal_encoder/clip_encoder.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
|
4 |
+
|
5 |
+
try:
|
6 |
+
from s2wrapper import forward as multiscale_forward
|
7 |
+
except:
|
8 |
+
pass
|
9 |
+
|
10 |
+
|
11 |
+
class CLIPVisionTower(nn.Module):
|
12 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
+
self.is_loaded = False
|
16 |
+
|
17 |
+
self.vision_tower_name = vision_tower
|
18 |
+
self.select_layer = args.mm_vision_select_layer
|
19 |
+
self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
|
20 |
+
|
21 |
+
if not delay_load:
|
22 |
+
print(f"Loading vision tower: {vision_tower}")
|
23 |
+
self.load_model()
|
24 |
+
elif getattr(args, "unfreeze_mm_vision_tower", False):
|
25 |
+
# TODO: better detector is needed.
|
26 |
+
print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
|
27 |
+
self.load_model()
|
28 |
+
elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
|
29 |
+
print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
|
30 |
+
self.load_model()
|
31 |
+
else:
|
32 |
+
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
|
33 |
+
|
34 |
+
def load_model(self, device_map=None):
|
35 |
+
if self.is_loaded:
|
36 |
+
print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
|
37 |
+
return
|
38 |
+
|
39 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
|
40 |
+
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
|
41 |
+
self.vision_tower.requires_grad_(False)
|
42 |
+
|
43 |
+
self.is_loaded = True
|
44 |
+
|
45 |
+
def feature_select(self, image_forward_outs):
|
46 |
+
select_feature_type = self.select_feature
|
47 |
+
|
48 |
+
if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
|
49 |
+
select_every_k_layer = len(image_forward_outs.hidden_states) // 4
|
50 |
+
image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1)
|
51 |
+
select_feature_type = select_feature_type.replace("slicefour_", "")
|
52 |
+
elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
|
53 |
+
select_layers = [-2, -5, -8, -11, 6]
|
54 |
+
image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1)
|
55 |
+
select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
|
56 |
+
else:
|
57 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
58 |
+
|
59 |
+
if select_feature_type == "patch":
|
60 |
+
image_features = image_features[:, 1:]
|
61 |
+
elif select_feature_type == "cls_patch":
|
62 |
+
image_features = image_features
|
63 |
+
else:
|
64 |
+
raise ValueError(f"Unexpected select feature: {select_feature_type}")
|
65 |
+
return image_features
|
66 |
+
|
67 |
+
def forward(self, images):
|
68 |
+
if type(images) is list:
|
69 |
+
image_features = []
|
70 |
+
for image in images:
|
71 |
+
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
72 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
73 |
+
image_features.append(image_feature)
|
74 |
+
else:
|
75 |
+
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
76 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
77 |
+
|
78 |
+
return image_features
|
79 |
+
|
80 |
+
@property
|
81 |
+
def dummy_feature(self):
|
82 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
83 |
+
|
84 |
+
@property
|
85 |
+
def dtype(self):
|
86 |
+
return self.vision_tower.dtype
|
87 |
+
|
88 |
+
@property
|
89 |
+
def device(self):
|
90 |
+
return self.vision_tower.device
|
91 |
+
|
92 |
+
@property
|
93 |
+
def config(self):
|
94 |
+
if self.is_loaded:
|
95 |
+
return self.vision_tower.config
|
96 |
+
else:
|
97 |
+
return self.cfg_only
|
98 |
+
|
99 |
+
@property
|
100 |
+
def hidden_size(self):
|
101 |
+
_hidden_size = self.config.hidden_size
|
102 |
+
if "slicefour" in self.select_feature:
|
103 |
+
_hidden_size *= 4
|
104 |
+
if "slice_m25811_f6" in self.select_feature:
|
105 |
+
_hidden_size *= 5
|
106 |
+
return _hidden_size
|
107 |
+
|
108 |
+
@property
|
109 |
+
def num_patches_per_side(self):
|
110 |
+
return self.config.image_size // self.config.patch_size
|
111 |
+
|
112 |
+
@property
|
113 |
+
def num_patches(self):
|
114 |
+
_num_patches = (self.config.image_size // self.config.patch_size) ** 2
|
115 |
+
if "cls_patch" in self.select_feature:
|
116 |
+
_num_patches += 1
|
117 |
+
return _num_patches
|
118 |
+
|
119 |
+
@property
|
120 |
+
def image_size(self):
|
121 |
+
return self.config.image_size
|
122 |
+
|
123 |
+
|
124 |
+
class CLIPVisionTowerS2(CLIPVisionTower):
|
125 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
126 |
+
|
127 |
+
self.s2_scales = getattr(args, "s2_scales", "336,672,1008")
|
128 |
+
self.s2_scales = list(map(int, self.s2_scales.split(",")))
|
129 |
+
self.s2_scales.sort()
|
130 |
+
self.s2_split_size = self.s2_scales[0]
|
131 |
+
self.s2_image_size = self.s2_scales[-1]
|
132 |
+
|
133 |
+
super().__init__(vision_tower, args, delay_load)
|
134 |
+
|
135 |
+
# change resize/crop size in preprocessing to the largest image size in s2_scale
|
136 |
+
if not delay_load or getattr(args, "unfreeze_mm_vision_tower", False):
|
137 |
+
self.image_processor.size["shortest_edge"] = self.s2_image_size
|
138 |
+
self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size
|
139 |
+
|
140 |
+
def load_model(self, device_map=None):
|
141 |
+
if self.is_loaded:
|
142 |
+
print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
|
143 |
+
return
|
144 |
+
|
145 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
|
146 |
+
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
|
147 |
+
self.vision_tower.requires_grad_(False)
|
148 |
+
|
149 |
+
self.image_processor.size["shortest_edge"] = self.s2_image_size
|
150 |
+
self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size
|
151 |
+
|
152 |
+
self.is_loaded = True
|
153 |
+
|
154 |
+
def forward_feature(self, images):
|
155 |
+
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
156 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
157 |
+
return image_features
|
158 |
+
|
159 |
+
def forward(self, images):
|
160 |
+
if type(images) is list:
|
161 |
+
image_features = []
|
162 |
+
for image in images:
|
163 |
+
image_feature = multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True)
|
164 |
+
image_features.append(image_feature)
|
165 |
+
else:
|
166 |
+
image_features = multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True)
|
167 |
+
|
168 |
+
return image_features
|
169 |
+
|
170 |
+
@property
|
171 |
+
def hidden_size(self):
|
172 |
+
return self.config.hidden_size * len(self.s2_scales)
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
2 |
+
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
|
3 |
+
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
|
4 |
+
from .loss import ClipLoss
|
5 |
+
from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
|
6 |
+
from .openai import load_openai_model, list_openai_models
|
7 |
+
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
|
8 |
+
from .tokenizer import SimpleTokenizer, tokenize
|
9 |
+
from .transform import image_transform
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
3 |
+
size 1356917
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
2 |
+
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/eva_vit_model.py
ADDED
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Adapted from https://github.com/microsoft/unilm/tree/master/beit
|
3 |
+
# --------------------------------------------------------
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
try:
|
11 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
12 |
+
except:
|
13 |
+
from timm.layers import drop_path, to_2tuple, trunc_normal_
|
14 |
+
|
15 |
+
from .transformer import PatchDropout
|
16 |
+
from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
|
17 |
+
|
18 |
+
if os.getenv("ENV_TYPE") == "deepspeed":
|
19 |
+
try:
|
20 |
+
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
|
21 |
+
except:
|
22 |
+
from torch.utils.checkpoint import checkpoint
|
23 |
+
else:
|
24 |
+
from torch.utils.checkpoint import checkpoint
|
25 |
+
|
26 |
+
try:
|
27 |
+
import xformers.ops as xops
|
28 |
+
except ImportError:
|
29 |
+
xops = None
|
30 |
+
# print("Please 'pip install xformers'")
|
31 |
+
|
32 |
+
|
33 |
+
class DropPath(nn.Module):
|
34 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
35 |
+
|
36 |
+
def __init__(self, drop_prob=None):
|
37 |
+
super(DropPath, self).__init__()
|
38 |
+
self.drop_prob = drop_prob
|
39 |
+
|
40 |
+
def forward(self, x):
|
41 |
+
return drop_path(x, self.drop_prob, self.training)
|
42 |
+
|
43 |
+
def extra_repr(self) -> str:
|
44 |
+
return "p={}".format(self.drop_prob)
|
45 |
+
|
46 |
+
|
47 |
+
class Mlp(nn.Module):
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
in_features,
|
51 |
+
hidden_features=None,
|
52 |
+
out_features=None,
|
53 |
+
act_layer=nn.GELU,
|
54 |
+
norm_layer=nn.LayerNorm,
|
55 |
+
drop=0.0,
|
56 |
+
subln=False,
|
57 |
+
):
|
58 |
+
super().__init__()
|
59 |
+
out_features = out_features or in_features
|
60 |
+
hidden_features = hidden_features or in_features
|
61 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
62 |
+
self.act = act_layer()
|
63 |
+
|
64 |
+
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
|
65 |
+
|
66 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
67 |
+
self.drop = nn.Dropout(drop)
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
x = self.fc1(x)
|
71 |
+
x = self.act(x)
|
72 |
+
# x = self.drop(x)
|
73 |
+
# commit this for the orignal BERT implement
|
74 |
+
x = self.ffn_ln(x)
|
75 |
+
|
76 |
+
x = self.fc2(x)
|
77 |
+
x = self.drop(x)
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
class SwiGLU(nn.Module):
|
82 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.0, norm_layer=nn.LayerNorm, subln=False):
|
83 |
+
super().__init__()
|
84 |
+
out_features = out_features or in_features
|
85 |
+
hidden_features = hidden_features or in_features
|
86 |
+
|
87 |
+
self.w1 = nn.Linear(in_features, hidden_features)
|
88 |
+
self.w2 = nn.Linear(in_features, hidden_features)
|
89 |
+
|
90 |
+
self.act = act_layer()
|
91 |
+
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
|
92 |
+
self.w3 = nn.Linear(hidden_features, out_features)
|
93 |
+
|
94 |
+
self.drop = nn.Dropout(drop)
|
95 |
+
|
96 |
+
def forward(self, x):
|
97 |
+
x1 = self.w1(x)
|
98 |
+
x2 = self.w2(x)
|
99 |
+
hidden = self.act(x1) * x2
|
100 |
+
x = self.ffn_ln(hidden)
|
101 |
+
x = self.w3(x)
|
102 |
+
x = self.drop(x)
|
103 |
+
return x
|
104 |
+
|
105 |
+
|
106 |
+
class Attention(nn.Module):
|
107 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
|
108 |
+
super().__init__()
|
109 |
+
self.num_heads = num_heads
|
110 |
+
head_dim = dim // num_heads
|
111 |
+
if attn_head_dim is not None:
|
112 |
+
head_dim = attn_head_dim
|
113 |
+
all_head_dim = head_dim * self.num_heads
|
114 |
+
self.scale = qk_scale or head_dim**-0.5
|
115 |
+
|
116 |
+
self.subln = subln
|
117 |
+
if self.subln:
|
118 |
+
self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
|
119 |
+
self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
|
120 |
+
self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
|
121 |
+
else:
|
122 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
123 |
+
|
124 |
+
if qkv_bias:
|
125 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
126 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
127 |
+
else:
|
128 |
+
self.q_bias = None
|
129 |
+
self.v_bias = None
|
130 |
+
|
131 |
+
if window_size:
|
132 |
+
self.window_size = window_size
|
133 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
134 |
+
self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
135 |
+
# cls to token & token 2 cls & cls to cls
|
136 |
+
|
137 |
+
# get pair-wise relative position index for each token inside the window
|
138 |
+
coords_h = torch.arange(window_size[0])
|
139 |
+
coords_w = torch.arange(window_size[1])
|
140 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
141 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
142 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
143 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
144 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
145 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
146 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
147 |
+
relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
148 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
149 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
150 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
151 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
152 |
+
|
153 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
154 |
+
else:
|
155 |
+
self.window_size = None
|
156 |
+
self.relative_position_bias_table = None
|
157 |
+
self.relative_position_index = None
|
158 |
+
|
159 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
160 |
+
self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
|
161 |
+
# self.proj = nn.Linear(all_head_dim, all_head_dim)
|
162 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
163 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
164 |
+
self.xattn = xattn
|
165 |
+
self.xattn_drop = attn_drop
|
166 |
+
|
167 |
+
self.rope = rope
|
168 |
+
|
169 |
+
def forward(self, x, rel_pos_bias=None, attn_mask=None):
|
170 |
+
B, N, C = x.shape
|
171 |
+
if self.subln:
|
172 |
+
q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
|
173 |
+
k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
|
174 |
+
v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
|
175 |
+
|
176 |
+
q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
|
177 |
+
k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
178 |
+
v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
179 |
+
else:
|
180 |
+
|
181 |
+
qkv_bias = None
|
182 |
+
if self.q_bias is not None:
|
183 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
184 |
+
|
185 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
186 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
|
187 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
188 |
+
|
189 |
+
if self.rope:
|
190 |
+
# slightly fast impl
|
191 |
+
q_t = q[:, :, 1:, :]
|
192 |
+
ro_q_t = self.rope(q_t)
|
193 |
+
q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
|
194 |
+
|
195 |
+
k_t = k[:, :, 1:, :]
|
196 |
+
ro_k_t = self.rope(k_t)
|
197 |
+
k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
|
198 |
+
|
199 |
+
if self.xattn:
|
200 |
+
q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
|
201 |
+
k = k.permute(0, 2, 1, 3)
|
202 |
+
v = v.permute(0, 2, 1, 3)
|
203 |
+
|
204 |
+
x = xops.memory_efficient_attention(
|
205 |
+
q,
|
206 |
+
k,
|
207 |
+
v,
|
208 |
+
p=self.xattn_drop,
|
209 |
+
scale=self.scale,
|
210 |
+
)
|
211 |
+
x = x.reshape(B, N, -1)
|
212 |
+
x = self.inner_attn_ln(x)
|
213 |
+
x = self.proj(x)
|
214 |
+
x = self.proj_drop(x)
|
215 |
+
else:
|
216 |
+
q = q * self.scale
|
217 |
+
attn = q @ k.transpose(-2, -1)
|
218 |
+
|
219 |
+
if self.relative_position_bias_table is not None:
|
220 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
221 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
222 |
+
attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
|
223 |
+
|
224 |
+
if rel_pos_bias is not None:
|
225 |
+
attn = attn + rel_pos_bias.type_as(attn)
|
226 |
+
|
227 |
+
if attn_mask is not None:
|
228 |
+
attn_mask = attn_mask.bool()
|
229 |
+
attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
|
230 |
+
|
231 |
+
attn = attn.softmax(dim=-1)
|
232 |
+
attn = self.attn_drop(attn)
|
233 |
+
|
234 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
235 |
+
x = self.inner_attn_ln(x)
|
236 |
+
x = self.proj(x)
|
237 |
+
x = self.proj_drop(x)
|
238 |
+
return x
|
239 |
+
|
240 |
+
|
241 |
+
class Block(nn.Module):
|
242 |
+
|
243 |
+
def __init__(
|
244 |
+
self,
|
245 |
+
dim,
|
246 |
+
num_heads,
|
247 |
+
mlp_ratio=4.0,
|
248 |
+
qkv_bias=False,
|
249 |
+
qk_scale=None,
|
250 |
+
drop=0.0,
|
251 |
+
attn_drop=0.0,
|
252 |
+
drop_path=0.0,
|
253 |
+
init_values=None,
|
254 |
+
act_layer=nn.GELU,
|
255 |
+
norm_layer=nn.LayerNorm,
|
256 |
+
window_size=None,
|
257 |
+
attn_head_dim=None,
|
258 |
+
xattn=False,
|
259 |
+
rope=None,
|
260 |
+
postnorm=False,
|
261 |
+
subln=False,
|
262 |
+
naiveswiglu=False,
|
263 |
+
):
|
264 |
+
super().__init__()
|
265 |
+
self.norm1 = norm_layer(dim)
|
266 |
+
self.attn = Attention(
|
267 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim, xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer
|
268 |
+
)
|
269 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
270 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
271 |
+
self.norm2 = norm_layer(dim)
|
272 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
273 |
+
|
274 |
+
if naiveswiglu:
|
275 |
+
self.mlp = SwiGLU(
|
276 |
+
in_features=dim,
|
277 |
+
hidden_features=mlp_hidden_dim,
|
278 |
+
subln=subln,
|
279 |
+
norm_layer=norm_layer,
|
280 |
+
)
|
281 |
+
else:
|
282 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, subln=subln, drop=drop)
|
283 |
+
|
284 |
+
if init_values is not None and init_values > 0:
|
285 |
+
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
|
286 |
+
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
|
287 |
+
else:
|
288 |
+
self.gamma_1, self.gamma_2 = None, None
|
289 |
+
|
290 |
+
self.postnorm = postnorm
|
291 |
+
|
292 |
+
def forward(self, x, rel_pos_bias=None, attn_mask=None):
|
293 |
+
if self.gamma_1 is None:
|
294 |
+
if self.postnorm:
|
295 |
+
x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
|
296 |
+
x = x + self.drop_path(self.norm2(self.mlp(x)))
|
297 |
+
else:
|
298 |
+
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
|
299 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
300 |
+
else:
|
301 |
+
if self.postnorm:
|
302 |
+
x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
|
303 |
+
x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
|
304 |
+
else:
|
305 |
+
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
|
306 |
+
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
307 |
+
return x
|
308 |
+
|
309 |
+
|
310 |
+
class PatchEmbed(nn.Module):
|
311 |
+
"""Image to Patch Embedding"""
|
312 |
+
|
313 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
314 |
+
super().__init__()
|
315 |
+
img_size = to_2tuple(img_size)
|
316 |
+
patch_size = to_2tuple(patch_size)
|
317 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
318 |
+
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
319 |
+
self.img_size = img_size
|
320 |
+
self.patch_size = patch_size
|
321 |
+
self.num_patches = num_patches
|
322 |
+
|
323 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
324 |
+
|
325 |
+
def forward(self, x, **kwargs):
|
326 |
+
B, C, H, W = x.shape
|
327 |
+
# FIXME look at relaxing size constraints
|
328 |
+
assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
329 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
330 |
+
return x
|
331 |
+
|
332 |
+
|
333 |
+
class RelativePositionBias(nn.Module):
|
334 |
+
|
335 |
+
def __init__(self, window_size, num_heads):
|
336 |
+
super().__init__()
|
337 |
+
self.window_size = window_size
|
338 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
339 |
+
self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
340 |
+
# cls to token & token 2 cls & cls to cls
|
341 |
+
|
342 |
+
# get pair-wise relative position index for each token inside the window
|
343 |
+
coords_h = torch.arange(window_size[0])
|
344 |
+
coords_w = torch.arange(window_size[1])
|
345 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
346 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
347 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
348 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
349 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
350 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
351 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
352 |
+
relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
353 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
354 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
355 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
356 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
357 |
+
|
358 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
359 |
+
|
360 |
+
def forward(self):
|
361 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
362 |
+
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
363 |
+
|
364 |
+
|
365 |
+
class EVAVisionTransformer(nn.Module):
|
366 |
+
"""Vision Transformer with support for patch or hybrid CNN input stage"""
|
367 |
+
|
368 |
+
def __init__(
|
369 |
+
self,
|
370 |
+
img_size=224,
|
371 |
+
patch_size=16,
|
372 |
+
in_chans=3,
|
373 |
+
num_classes=1000,
|
374 |
+
embed_dim=768,
|
375 |
+
depth=12,
|
376 |
+
num_heads=12,
|
377 |
+
mlp_ratio=4.0,
|
378 |
+
qkv_bias=False,
|
379 |
+
qk_scale=None,
|
380 |
+
drop_rate=0.0,
|
381 |
+
attn_drop_rate=0.0,
|
382 |
+
drop_path_rate=0.0,
|
383 |
+
norm_layer=nn.LayerNorm,
|
384 |
+
init_values=None,
|
385 |
+
patch_dropout=0.0,
|
386 |
+
use_abs_pos_emb=True,
|
387 |
+
use_rel_pos_bias=False,
|
388 |
+
use_shared_rel_pos_bias=False,
|
389 |
+
rope=False,
|
390 |
+
use_mean_pooling=True,
|
391 |
+
init_scale=0.001,
|
392 |
+
grad_checkpointing=False,
|
393 |
+
xattn=False,
|
394 |
+
postnorm=False,
|
395 |
+
pt_hw_seq_len=16,
|
396 |
+
intp_freq=False,
|
397 |
+
naiveswiglu=False,
|
398 |
+
subln=False,
|
399 |
+
):
|
400 |
+
super().__init__()
|
401 |
+
self.image_size = img_size
|
402 |
+
self.num_classes = num_classes
|
403 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
404 |
+
|
405 |
+
self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
406 |
+
num_patches = self.patch_embed.num_patches
|
407 |
+
|
408 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
409 |
+
# self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
410 |
+
if use_abs_pos_emb:
|
411 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
412 |
+
else:
|
413 |
+
self.pos_embed = None
|
414 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
415 |
+
|
416 |
+
if use_shared_rel_pos_bias:
|
417 |
+
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
|
418 |
+
else:
|
419 |
+
self.rel_pos_bias = None
|
420 |
+
|
421 |
+
if rope:
|
422 |
+
half_head_dim = embed_dim // num_heads // 2
|
423 |
+
hw_seq_len = img_size // patch_size
|
424 |
+
self.rope = VisionRotaryEmbeddingFast(
|
425 |
+
dim=half_head_dim,
|
426 |
+
pt_seq_len=pt_hw_seq_len,
|
427 |
+
ft_seq_len=hw_seq_len if intp_freq else None,
|
428 |
+
# patch_dropout=patch_dropout
|
429 |
+
)
|
430 |
+
else:
|
431 |
+
self.rope = None
|
432 |
+
|
433 |
+
self.naiveswiglu = naiveswiglu
|
434 |
+
|
435 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
436 |
+
self.use_rel_pos_bias = use_rel_pos_bias
|
437 |
+
self.blocks = nn.ModuleList(
|
438 |
+
[
|
439 |
+
Block(
|
440 |
+
dim=embed_dim,
|
441 |
+
num_heads=num_heads,
|
442 |
+
mlp_ratio=mlp_ratio,
|
443 |
+
qkv_bias=qkv_bias,
|
444 |
+
qk_scale=qk_scale,
|
445 |
+
drop=drop_rate,
|
446 |
+
attn_drop=attn_drop_rate,
|
447 |
+
drop_path=dpr[i],
|
448 |
+
norm_layer=norm_layer,
|
449 |
+
init_values=init_values,
|
450 |
+
window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
|
451 |
+
xattn=xattn,
|
452 |
+
rope=self.rope,
|
453 |
+
postnorm=postnorm,
|
454 |
+
subln=subln,
|
455 |
+
naiveswiglu=naiveswiglu,
|
456 |
+
)
|
457 |
+
for i in range(depth)
|
458 |
+
]
|
459 |
+
)
|
460 |
+
self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
461 |
+
self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
462 |
+
self.head = nn.Linear(embed_dim, num_classes, bias=qkv_bias) if num_classes > 0 else nn.Identity()
|
463 |
+
|
464 |
+
if self.pos_embed is not None:
|
465 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
466 |
+
|
467 |
+
trunc_normal_(self.cls_token, std=0.02)
|
468 |
+
|
469 |
+
self.apply(self._init_weights)
|
470 |
+
self.fix_init_weight()
|
471 |
+
|
472 |
+
if isinstance(self.head, nn.Linear):
|
473 |
+
trunc_normal_(self.head.weight, std=0.02)
|
474 |
+
self.head.weight.data.mul_(init_scale)
|
475 |
+
if self.head.bias is not None:
|
476 |
+
self.head.bias.data.mul_(init_scale)
|
477 |
+
|
478 |
+
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
|
479 |
+
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
|
480 |
+
|
481 |
+
self.grad_checkpointing = grad_checkpointing
|
482 |
+
|
483 |
+
def fix_init_weight(self):
|
484 |
+
def rescale(param, layer_id):
|
485 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
486 |
+
|
487 |
+
for layer_id, layer in enumerate(self.blocks):
|
488 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
489 |
+
if self.naiveswiglu:
|
490 |
+
rescale(layer.mlp.w3.weight.data, layer_id + 1)
|
491 |
+
else:
|
492 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
493 |
+
|
494 |
+
def get_cast_dtype(self) -> torch.dtype:
|
495 |
+
return self.blocks[0].mlp.fc2.weight.dtype
|
496 |
+
|
497 |
+
def _init_weights(self, m):
|
498 |
+
if isinstance(m, nn.Linear):
|
499 |
+
trunc_normal_(m.weight, std=0.02)
|
500 |
+
if m.bias is not None:
|
501 |
+
nn.init.constant_(m.bias, 0)
|
502 |
+
elif isinstance(m, nn.LayerNorm):
|
503 |
+
nn.init.constant_(m.bias, 0)
|
504 |
+
nn.init.constant_(m.weight, 1.0)
|
505 |
+
|
506 |
+
def get_num_layers(self):
|
507 |
+
return len(self.blocks)
|
508 |
+
|
509 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
510 |
+
assert unlocked_groups == 0, "partial locking not currently supported for this model"
|
511 |
+
for param in self.parameters():
|
512 |
+
param.requires_grad = False
|
513 |
+
|
514 |
+
@torch.jit.ignore
|
515 |
+
def set_grad_checkpointing(self, enable=True):
|
516 |
+
self.grad_checkpointing = enable
|
517 |
+
|
518 |
+
@torch.jit.ignore
|
519 |
+
def no_weight_decay(self):
|
520 |
+
return {"pos_embed", "cls_token"}
|
521 |
+
|
522 |
+
def get_classifier(self):
|
523 |
+
return self.head
|
524 |
+
|
525 |
+
def reset_classifier(self, num_classes, global_pool=""):
|
526 |
+
self.num_classes = num_classes
|
527 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
528 |
+
|
529 |
+
def forward_features(self, x, return_all_features=False):
|
530 |
+
|
531 |
+
x = self.patch_embed(x)
|
532 |
+
batch_size, seq_len, _ = x.size()
|
533 |
+
|
534 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
535 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
536 |
+
if self.pos_embed is not None:
|
537 |
+
x = x + self.pos_embed
|
538 |
+
x = self.pos_drop(x)
|
539 |
+
|
540 |
+
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
|
541 |
+
# if os.getenv("RoPE") == "1":
|
542 |
+
# if self.training and not isinstance(self.patch_dropout, nn.Identity):
|
543 |
+
# x, patch_indices_keep = self.patch_dropout(x)
|
544 |
+
# self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
|
545 |
+
# else:
|
546 |
+
# self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
|
547 |
+
# x = self.patch_dropout(x)
|
548 |
+
# else:
|
549 |
+
x = self.patch_dropout(x)
|
550 |
+
|
551 |
+
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
552 |
+
for blk in self.blocks:
|
553 |
+
if self.grad_checkpointing:
|
554 |
+
x = checkpoint(blk, x, (rel_pos_bias,))
|
555 |
+
else:
|
556 |
+
x = blk(x, rel_pos_bias=rel_pos_bias)
|
557 |
+
|
558 |
+
if not return_all_features:
|
559 |
+
x = self.norm(x)
|
560 |
+
if self.fc_norm is not None:
|
561 |
+
return self.fc_norm(x.mean(1))
|
562 |
+
else:
|
563 |
+
return x[:, 0]
|
564 |
+
return x
|
565 |
+
|
566 |
+
def forward(self, x, return_all_features=False):
|
567 |
+
if return_all_features:
|
568 |
+
return self.forward_features(x, return_all_features)
|
569 |
+
x = self.forward_features(x)
|
570 |
+
x = self.head(x)
|
571 |
+
return x
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/factory.py
ADDED
@@ -0,0 +1,528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import pathlib
|
5 |
+
import re
|
6 |
+
from copy import deepcopy
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Optional, Tuple, Union, Dict, Any
|
9 |
+
import torch
|
10 |
+
|
11 |
+
try:
|
12 |
+
import deepspeed
|
13 |
+
except ImportError:
|
14 |
+
deepspeed = None
|
15 |
+
|
16 |
+
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
17 |
+
from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict, get_cast_dtype
|
18 |
+
from .openai import load_openai_model
|
19 |
+
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
|
20 |
+
from .transform import image_transform
|
21 |
+
from .tokenizer import HFTokenizer, tokenize
|
22 |
+
from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed
|
23 |
+
|
24 |
+
|
25 |
+
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
|
26 |
+
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
|
27 |
+
|
28 |
+
|
29 |
+
def _natural_key(string_):
|
30 |
+
return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
|
31 |
+
|
32 |
+
|
33 |
+
def _rescan_model_configs():
|
34 |
+
global _MODEL_CONFIGS
|
35 |
+
|
36 |
+
config_ext = (".json",)
|
37 |
+
config_files = []
|
38 |
+
for config_path in _MODEL_CONFIG_PATHS:
|
39 |
+
if config_path.is_file() and config_path.suffix in config_ext:
|
40 |
+
config_files.append(config_path)
|
41 |
+
elif config_path.is_dir():
|
42 |
+
for ext in config_ext:
|
43 |
+
config_files.extend(config_path.glob(f"*{ext}"))
|
44 |
+
|
45 |
+
for cf in config_files:
|
46 |
+
with open(cf, "r", encoding="utf8") as f:
|
47 |
+
model_cfg = json.load(f)
|
48 |
+
if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")):
|
49 |
+
_MODEL_CONFIGS[cf.stem] = model_cfg
|
50 |
+
|
51 |
+
_MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
|
52 |
+
|
53 |
+
|
54 |
+
_rescan_model_configs() # initial populate of model config registry
|
55 |
+
|
56 |
+
|
57 |
+
def list_models():
|
58 |
+
"""enumerate available model architectures based on config files"""
|
59 |
+
return list(_MODEL_CONFIGS.keys())
|
60 |
+
|
61 |
+
|
62 |
+
def add_model_config(path):
|
63 |
+
"""add model config path or file and update registry"""
|
64 |
+
if not isinstance(path, Path):
|
65 |
+
path = Path(path)
|
66 |
+
_MODEL_CONFIG_PATHS.append(path)
|
67 |
+
_rescan_model_configs()
|
68 |
+
|
69 |
+
|
70 |
+
def get_model_config(model_name):
|
71 |
+
if model_name in _MODEL_CONFIGS:
|
72 |
+
return deepcopy(_MODEL_CONFIGS[model_name])
|
73 |
+
else:
|
74 |
+
return None
|
75 |
+
|
76 |
+
|
77 |
+
def get_tokenizer(model_name):
|
78 |
+
config = get_model_config(model_name)
|
79 |
+
tokenizer = HFTokenizer(config["text_cfg"]["hf_tokenizer_name"]) if "hf_tokenizer_name" in config["text_cfg"] else tokenize
|
80 |
+
return tokenizer
|
81 |
+
|
82 |
+
|
83 |
+
# loading openai CLIP weights when is_openai=True for training
|
84 |
+
def load_state_dict(checkpoint_path: str, map_location: str = "cpu", model_key: str = "model|module|state_dict", is_openai: bool = False, skip_list: list = []):
|
85 |
+
if is_openai:
|
86 |
+
model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
|
87 |
+
state_dict = model.state_dict()
|
88 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
89 |
+
state_dict.pop(key, None)
|
90 |
+
else:
|
91 |
+
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
92 |
+
for mk in model_key.split("|"):
|
93 |
+
if isinstance(checkpoint, dict) and mk in checkpoint:
|
94 |
+
state_dict = checkpoint[mk]
|
95 |
+
break
|
96 |
+
else:
|
97 |
+
state_dict = checkpoint
|
98 |
+
if next(iter(state_dict.items()))[0].startswith("module"):
|
99 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
100 |
+
|
101 |
+
for k in skip_list:
|
102 |
+
if k in list(state_dict.keys()):
|
103 |
+
logging.info(f"Removing key {k} from pretrained checkpoint")
|
104 |
+
del state_dict[k]
|
105 |
+
|
106 |
+
if os.getenv("RoPE") == "1":
|
107 |
+
for k in list(state_dict.keys()):
|
108 |
+
if "freqs_cos" in k or "freqs_sin" in k:
|
109 |
+
del state_dict[k]
|
110 |
+
return state_dict
|
111 |
+
|
112 |
+
|
113 |
+
def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True):
|
114 |
+
state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)
|
115 |
+
# detect old format and make compatible with new format
|
116 |
+
if "positional_embedding" in state_dict and not hasattr(model, "positional_embedding"):
|
117 |
+
state_dict = convert_to_custom_text_state_dict(state_dict)
|
118 |
+
if "text.logit_scale" in state_dict and hasattr(model, "logit_scale"):
|
119 |
+
state_dict["logit_scale"] = state_dict["text.logit_scale"]
|
120 |
+
del state_dict["text.logit_scale"]
|
121 |
+
|
122 |
+
# resize_clip_pos_embed for CLIP and open CLIP
|
123 |
+
if "visual.positional_embedding" in state_dict:
|
124 |
+
resize_clip_pos_embed(state_dict, model)
|
125 |
+
# specified to eva_vit_model
|
126 |
+
elif "visual.pos_embed" in state_dict:
|
127 |
+
resize_evaclip_pos_embed(state_dict, model)
|
128 |
+
|
129 |
+
# resize_clip_pos_embed(state_dict, model)
|
130 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
131 |
+
logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
|
132 |
+
return incompatible_keys
|
133 |
+
|
134 |
+
|
135 |
+
def load_clip_visual_state_dict(checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []):
|
136 |
+
state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
|
137 |
+
|
138 |
+
for k in list(state_dict.keys()):
|
139 |
+
if not k.startswith("visual."):
|
140 |
+
del state_dict[k]
|
141 |
+
for k in list(state_dict.keys()):
|
142 |
+
if k.startswith("visual."):
|
143 |
+
new_k = k[7:]
|
144 |
+
state_dict[new_k] = state_dict[k]
|
145 |
+
del state_dict[k]
|
146 |
+
return state_dict
|
147 |
+
|
148 |
+
|
149 |
+
def load_clip_text_state_dict(checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []):
|
150 |
+
state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
|
151 |
+
|
152 |
+
for k in list(state_dict.keys()):
|
153 |
+
if k.startswith("visual."):
|
154 |
+
del state_dict[k]
|
155 |
+
return state_dict
|
156 |
+
|
157 |
+
|
158 |
+
def get_pretrained_tag(pretrained_model):
|
159 |
+
pretrained_model = pretrained_model.lower()
|
160 |
+
if "laion" in pretrained_model or "open_clip" in pretrained_model:
|
161 |
+
return "open_clip"
|
162 |
+
elif "openai" in pretrained_model:
|
163 |
+
return "clip"
|
164 |
+
elif "eva" in pretrained_model and "clip" in pretrained_model:
|
165 |
+
return "eva_clip"
|
166 |
+
else:
|
167 |
+
return "other"
|
168 |
+
|
169 |
+
|
170 |
+
def load_zero_partitions(model, state_dict, is_deepspeed_zero3_enabled, pretrained_model_path, ignore_mismatched_sizes=False):
|
171 |
+
"""
|
172 |
+
adept from pytorch lightning and transformers
|
173 |
+
with deepspeed.zero.Init():
|
174 |
+
model = MyModel()
|
175 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
176 |
+
load_zero_partitions(model, prefix="")
|
177 |
+
"""
|
178 |
+
|
179 |
+
# because zero3 puts placeholders in model params, this context
|
180 |
+
# manager gathers (unpartitions) the params of the current layer, then loads from
|
181 |
+
# the state dict and then re-partitions them again
|
182 |
+
model_state_dict = model.state_dict()
|
183 |
+
expected_keys = list(model_state_dict.keys())
|
184 |
+
loaded_keys = list(state_dict.keys())
|
185 |
+
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
186 |
+
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
187 |
+
|
188 |
+
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
189 |
+
# matching the weights in the model.
|
190 |
+
mismatched_keys = []
|
191 |
+
if ignore_mismatched_sizes:
|
192 |
+
for checkpoint_key in loaded_keys:
|
193 |
+
model_key = checkpoint_key
|
194 |
+
|
195 |
+
if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
|
196 |
+
mismatched_keys.append((checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape))
|
197 |
+
del state_dict[checkpoint_key]
|
198 |
+
# copy state_dict so _load_from_state_dict can modify it
|
199 |
+
metadata = getattr(state_dict, "_metadata", None)
|
200 |
+
state_dict = state_dict.copy()
|
201 |
+
if metadata is not None:
|
202 |
+
state_dict._metadata = metadata
|
203 |
+
|
204 |
+
error_msgs = []
|
205 |
+
|
206 |
+
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
207 |
+
# so we need to apply the function recursively.
|
208 |
+
def load(module, prefix=""):
|
209 |
+
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
210 |
+
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
|
211 |
+
if is_deepspeed_zero3_enabled:
|
212 |
+
# because zero3 puts placeholders in model params, this context
|
213 |
+
# manager gathers (unpartitions) the params of the current layer, then loads from
|
214 |
+
# the state dict and then re-partitions them again
|
215 |
+
with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
|
216 |
+
if torch.distributed.get_rank() == 0:
|
217 |
+
module._load_from_state_dict(*args)
|
218 |
+
else:
|
219 |
+
module._load_from_state_dict(*args)
|
220 |
+
|
221 |
+
for name, child in module._modules.items():
|
222 |
+
if child is not None:
|
223 |
+
load(child, prefix + name + ".")
|
224 |
+
|
225 |
+
# Make sure we are able to load base models as well as derived models (with heads)
|
226 |
+
start_prefix = ""
|
227 |
+
model_to_load = model
|
228 |
+
load(model_to_load, prefix=start_prefix)
|
229 |
+
del state_dict
|
230 |
+
if len(error_msgs) > 0:
|
231 |
+
error_msg = "\n\t".join(error_msgs)
|
232 |
+
if "size mismatch" in error_msg:
|
233 |
+
error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
|
234 |
+
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
235 |
+
if len(unexpected_keys) > 0:
|
236 |
+
logging.warning(
|
237 |
+
f"Some weights of the model checkpoint at {pretrained_model_path} were not used when"
|
238 |
+
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
239 |
+
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
|
240 |
+
" with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
241 |
+
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
242 |
+
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
|
243 |
+
" (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
|
244 |
+
)
|
245 |
+
else:
|
246 |
+
logging.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
247 |
+
if len(missing_keys) > 0:
|
248 |
+
logging.warning(
|
249 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
250 |
+
f" {pretrained_model_path} and are newly initialized: {missing_keys}\nYou should probably"
|
251 |
+
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
252 |
+
)
|
253 |
+
elif len(mismatched_keys) == 0:
|
254 |
+
logging.info(
|
255 |
+
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
256 |
+
f" {pretrained_model_path}.\nIf your task is similar to the task the model of the checkpoint"
|
257 |
+
f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
|
258 |
+
" training."
|
259 |
+
)
|
260 |
+
if len(mismatched_keys) > 0:
|
261 |
+
mismatched_warning = "\n".join([f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" for key, shape1, shape2 in mismatched_keys])
|
262 |
+
logging.warning(
|
263 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
264 |
+
f" {pretrained_model_path} and are newly initialized because the shapes did not"
|
265 |
+
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
|
266 |
+
" to use it for predictions and inference."
|
267 |
+
)
|
268 |
+
|
269 |
+
|
270 |
+
def load_pretrained_checkpoint(model, visual_checkpoint_path, text_checkpoint_path, strict=True, visual_model=None, text_model=None, model_key="model|module|state_dict", skip_list=[]):
|
271 |
+
visual_tag = get_pretrained_tag(visual_model)
|
272 |
+
text_tag = get_pretrained_tag(text_model)
|
273 |
+
|
274 |
+
logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}")
|
275 |
+
visual_incompatible_keys, text_incompatible_keys = None, None
|
276 |
+
if visual_checkpoint_path:
|
277 |
+
if visual_tag == "eva_clip" or visual_tag == "open_clip":
|
278 |
+
visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list)
|
279 |
+
elif visual_tag == "clip":
|
280 |
+
visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
|
281 |
+
else:
|
282 |
+
visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
|
283 |
+
|
284 |
+
# resize_clip_pos_embed for CLIP and open CLIP
|
285 |
+
if "positional_embedding" in visual_state_dict:
|
286 |
+
resize_visual_pos_embed(visual_state_dict, model)
|
287 |
+
# specified to EVA model
|
288 |
+
elif "pos_embed" in visual_state_dict:
|
289 |
+
resize_eva_pos_embed(visual_state_dict, model)
|
290 |
+
|
291 |
+
visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)
|
292 |
+
logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}")
|
293 |
+
logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}")
|
294 |
+
|
295 |
+
if text_checkpoint_path:
|
296 |
+
if text_tag == "eva_clip" or text_tag == "open_clip":
|
297 |
+
text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)
|
298 |
+
elif text_tag == "clip":
|
299 |
+
text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)
|
300 |
+
else:
|
301 |
+
text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
|
302 |
+
|
303 |
+
text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)
|
304 |
+
|
305 |
+
logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
|
306 |
+
logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")
|
307 |
+
|
308 |
+
return visual_incompatible_keys, text_incompatible_keys
|
309 |
+
|
310 |
+
|
311 |
+
def create_model(
|
312 |
+
model_name: str,
|
313 |
+
pretrained: Optional[str] = None,
|
314 |
+
precision: str = "fp32",
|
315 |
+
device: Union[str, torch.device] = "cpu",
|
316 |
+
jit: bool = False,
|
317 |
+
force_quick_gelu: bool = False,
|
318 |
+
force_custom_clip: bool = False,
|
319 |
+
force_patch_dropout: Optional[float] = None,
|
320 |
+
pretrained_image: str = "",
|
321 |
+
pretrained_text: str = "",
|
322 |
+
pretrained_hf: bool = True,
|
323 |
+
pretrained_visual_model: str = None,
|
324 |
+
pretrained_text_model: str = None,
|
325 |
+
cache_dir: Optional[str] = None,
|
326 |
+
skip_list: list = [],
|
327 |
+
):
|
328 |
+
model_name = model_name.replace("/", "-") # for callers using old naming with / in ViT names
|
329 |
+
if isinstance(device, str):
|
330 |
+
device = torch.device(device)
|
331 |
+
|
332 |
+
if pretrained and pretrained.lower() == "openai":
|
333 |
+
logging.info(f"Loading pretrained {model_name} from OpenAI.")
|
334 |
+
model = load_openai_model(
|
335 |
+
model_name,
|
336 |
+
precision=precision,
|
337 |
+
device=device,
|
338 |
+
jit=jit,
|
339 |
+
cache_dir=cache_dir,
|
340 |
+
)
|
341 |
+
else:
|
342 |
+
model_cfg = get_model_config(model_name)
|
343 |
+
if model_cfg is not None:
|
344 |
+
logging.info(f"Loaded {model_name} model config.")
|
345 |
+
else:
|
346 |
+
logging.error(f"Model config for {model_name} not found; available models {list_models()}.")
|
347 |
+
raise RuntimeError(f"Model config for {model_name} not found.")
|
348 |
+
|
349 |
+
if "rope" in model_cfg.get("vision_cfg", {}):
|
350 |
+
if model_cfg["vision_cfg"]["rope"]:
|
351 |
+
os.environ["RoPE"] = "1"
|
352 |
+
else:
|
353 |
+
os.environ["RoPE"] = "0"
|
354 |
+
|
355 |
+
if force_quick_gelu:
|
356 |
+
# override for use of QuickGELU on non-OpenAI transformer models
|
357 |
+
model_cfg["quick_gelu"] = True
|
358 |
+
|
359 |
+
if force_patch_dropout is not None:
|
360 |
+
# override the default patch dropout value
|
361 |
+
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
|
362 |
+
|
363 |
+
cast_dtype = get_cast_dtype(precision)
|
364 |
+
custom_clip = model_cfg.pop("custom_text", False) or force_custom_clip or ("hf_model_name" in model_cfg["text_cfg"])
|
365 |
+
|
366 |
+
if custom_clip:
|
367 |
+
if "hf_model_name" in model_cfg.get("text_cfg", {}):
|
368 |
+
model_cfg["text_cfg"]["hf_model_pretrained"] = pretrained_hf
|
369 |
+
model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype)
|
370 |
+
else:
|
371 |
+
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
|
372 |
+
|
373 |
+
pretrained_cfg = {}
|
374 |
+
if pretrained:
|
375 |
+
checkpoint_path = ""
|
376 |
+
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
|
377 |
+
if pretrained_cfg:
|
378 |
+
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
|
379 |
+
elif os.path.exists(pretrained):
|
380 |
+
checkpoint_path = pretrained
|
381 |
+
|
382 |
+
if checkpoint_path:
|
383 |
+
logging.info(f"Loading pretrained {model_name} weights ({pretrained}).")
|
384 |
+
load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=False)
|
385 |
+
else:
|
386 |
+
error_str = f"Pretrained weights ({pretrained}) not found for model {model_name}." f"Available pretrained tags ({list_pretrained_tags_by_model(model_name)}."
|
387 |
+
logging.warning(error_str)
|
388 |
+
raise RuntimeError(error_str)
|
389 |
+
else:
|
390 |
+
visual_checkpoint_path = ""
|
391 |
+
text_checkpoint_path = ""
|
392 |
+
|
393 |
+
if pretrained_image:
|
394 |
+
pretrained_visual_model = pretrained_visual_model.replace("/", "-") # for callers using old naming with / in ViT names
|
395 |
+
pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)
|
396 |
+
if "timm_model_name" in model_cfg.get("vision_cfg", {}):
|
397 |
+
# pretrained weight loading for timm models set via vision_cfg
|
398 |
+
model_cfg["vision_cfg"]["timm_model_pretrained"] = True
|
399 |
+
elif pretrained_image_cfg:
|
400 |
+
visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir)
|
401 |
+
elif os.path.exists(pretrained_image):
|
402 |
+
visual_checkpoint_path = pretrained_image
|
403 |
+
else:
|
404 |
+
logging.warning(f"Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.")
|
405 |
+
raise RuntimeError(f"Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.")
|
406 |
+
|
407 |
+
if pretrained_text:
|
408 |
+
pretrained_text_model = pretrained_text_model.replace("/", "-") # for callers using old naming with / in ViT names
|
409 |
+
pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text)
|
410 |
+
if pretrained_image_cfg:
|
411 |
+
text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir)
|
412 |
+
elif os.path.exists(pretrained_text):
|
413 |
+
text_checkpoint_path = pretrained_text
|
414 |
+
else:
|
415 |
+
logging.warning(f"Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.")
|
416 |
+
raise RuntimeError(f"Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.")
|
417 |
+
|
418 |
+
if visual_checkpoint_path:
|
419 |
+
logging.info(f"Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).")
|
420 |
+
if text_checkpoint_path:
|
421 |
+
logging.info(f"Loading pretrained {model_name}.text weights ({text_checkpoint_path}).")
|
422 |
+
|
423 |
+
if visual_checkpoint_path or text_checkpoint_path:
|
424 |
+
load_pretrained_checkpoint(model, visual_checkpoint_path, text_checkpoint_path, strict=False, visual_model=pretrained_visual_model, text_model=pretrained_text_model, model_key="model|module|state_dict", skip_list=skip_list)
|
425 |
+
|
426 |
+
if "fp16" in precision or "bf16" in precision:
|
427 |
+
logging.info(f"convert precision to {precision}")
|
428 |
+
model = model.to(torch.bfloat16) if "bf16" in precision else model.to(torch.float16)
|
429 |
+
|
430 |
+
# model.to(device=device)
|
431 |
+
|
432 |
+
# set image / mean metadata from pretrained_cfg if available, or use default
|
433 |
+
model.visual.image_mean = pretrained_cfg.get("mean", None) or OPENAI_DATASET_MEAN
|
434 |
+
model.visual.image_std = pretrained_cfg.get("std", None) or OPENAI_DATASET_STD
|
435 |
+
|
436 |
+
if jit:
|
437 |
+
model = torch.jit.script(model)
|
438 |
+
|
439 |
+
return model
|
440 |
+
|
441 |
+
|
442 |
+
def create_model_and_transforms(
|
443 |
+
model_name: str,
|
444 |
+
pretrained: Optional[str] = None,
|
445 |
+
precision: str = "fp32",
|
446 |
+
device: Union[str, torch.device] = "cpu",
|
447 |
+
jit: bool = False,
|
448 |
+
force_quick_gelu: bool = False,
|
449 |
+
force_custom_clip: bool = False,
|
450 |
+
force_patch_dropout: Optional[float] = None,
|
451 |
+
pretrained_image: str = "",
|
452 |
+
pretrained_text: str = "",
|
453 |
+
pretrained_hf: bool = True,
|
454 |
+
pretrained_visual_model: str = None,
|
455 |
+
pretrained_text_model: str = None,
|
456 |
+
image_mean: Optional[Tuple[float, ...]] = None,
|
457 |
+
image_std: Optional[Tuple[float, ...]] = None,
|
458 |
+
cache_dir: Optional[str] = None,
|
459 |
+
skip_list: list = [],
|
460 |
+
):
|
461 |
+
model = create_model(
|
462 |
+
model_name,
|
463 |
+
pretrained,
|
464 |
+
precision=precision,
|
465 |
+
device=device,
|
466 |
+
jit=jit,
|
467 |
+
force_quick_gelu=force_quick_gelu,
|
468 |
+
force_custom_clip=force_custom_clip,
|
469 |
+
force_patch_dropout=force_patch_dropout,
|
470 |
+
pretrained_image=pretrained_image,
|
471 |
+
pretrained_text=pretrained_text,
|
472 |
+
pretrained_hf=pretrained_hf,
|
473 |
+
pretrained_visual_model=pretrained_visual_model,
|
474 |
+
pretrained_text_model=pretrained_text_model,
|
475 |
+
cache_dir=cache_dir,
|
476 |
+
skip_list=skip_list,
|
477 |
+
)
|
478 |
+
|
479 |
+
image_mean = image_mean or getattr(model.visual, "image_mean", None)
|
480 |
+
image_std = image_std or getattr(model.visual, "image_std", None)
|
481 |
+
preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=image_mean, std=image_std)
|
482 |
+
preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=image_mean, std=image_std)
|
483 |
+
|
484 |
+
return model, preprocess_train, preprocess_val
|
485 |
+
|
486 |
+
|
487 |
+
def create_model_from_pretrained(
|
488 |
+
model_name: str,
|
489 |
+
pretrained: str,
|
490 |
+
precision: str = "fp32",
|
491 |
+
device: Union[str, torch.device] = "cpu",
|
492 |
+
jit: bool = False,
|
493 |
+
force_quick_gelu: bool = False,
|
494 |
+
force_custom_clip: bool = False,
|
495 |
+
force_patch_dropout: Optional[float] = None,
|
496 |
+
return_transform: bool = True,
|
497 |
+
image_mean: Optional[Tuple[float, ...]] = None,
|
498 |
+
image_std: Optional[Tuple[float, ...]] = None,
|
499 |
+
cache_dir: Optional[str] = None,
|
500 |
+
is_frozen: bool = False,
|
501 |
+
):
|
502 |
+
if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
|
503 |
+
raise RuntimeError(f"{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}." f" Use open_clip.list_pretrained() to find one.")
|
504 |
+
|
505 |
+
model = create_model(
|
506 |
+
model_name,
|
507 |
+
pretrained,
|
508 |
+
precision=precision,
|
509 |
+
device=device,
|
510 |
+
jit=jit,
|
511 |
+
force_quick_gelu=force_quick_gelu,
|
512 |
+
force_custom_clip=force_custom_clip,
|
513 |
+
force_patch_dropout=force_patch_dropout,
|
514 |
+
cache_dir=cache_dir,
|
515 |
+
)
|
516 |
+
|
517 |
+
if is_frozen:
|
518 |
+
for param in model.parameters():
|
519 |
+
param.requires_grad = False
|
520 |
+
|
521 |
+
if not return_transform:
|
522 |
+
return model
|
523 |
+
|
524 |
+
image_mean = image_mean or getattr(model.visual, "image_mean", None)
|
525 |
+
image_std = image_std or getattr(model.visual, "image_std", None)
|
526 |
+
preprocess = image_transform(model.visual.image_size, is_train=False, mean=image_mean, std=image_std)
|
527 |
+
|
528 |
+
return model, preprocess
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# HF architecture dict:
|
2 |
+
arch_dict = {
|
3 |
+
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
|
4 |
+
"roberta": {
|
5 |
+
"config_names": {
|
6 |
+
"context_length": "max_position_embeddings",
|
7 |
+
"vocab_size": "vocab_size",
|
8 |
+
"width": "hidden_size",
|
9 |
+
"heads": "num_attention_heads",
|
10 |
+
"layers": "num_hidden_layers",
|
11 |
+
"layer_attr": "layer",
|
12 |
+
"token_embeddings_attr": "embeddings",
|
13 |
+
},
|
14 |
+
"pooler": "mean_pooler",
|
15 |
+
},
|
16 |
+
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
|
17 |
+
"xlm-roberta": {
|
18 |
+
"config_names": {
|
19 |
+
"context_length": "max_position_embeddings",
|
20 |
+
"vocab_size": "vocab_size",
|
21 |
+
"width": "hidden_size",
|
22 |
+
"heads": "num_attention_heads",
|
23 |
+
"layers": "num_hidden_layers",
|
24 |
+
"layer_attr": "layer",
|
25 |
+
"token_embeddings_attr": "embeddings",
|
26 |
+
},
|
27 |
+
"pooler": "mean_pooler",
|
28 |
+
},
|
29 |
+
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
|
30 |
+
"mt5": {
|
31 |
+
"config_names": {
|
32 |
+
# unlimited seqlen
|
33 |
+
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
|
34 |
+
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
|
35 |
+
"context_length": "",
|
36 |
+
"vocab_size": "vocab_size",
|
37 |
+
"width": "d_model",
|
38 |
+
"heads": "num_heads",
|
39 |
+
"layers": "num_layers",
|
40 |
+
"layer_attr": "block",
|
41 |
+
"token_embeddings_attr": "embed_tokens",
|
42 |
+
},
|
43 |
+
"pooler": "mean_pooler",
|
44 |
+
},
|
45 |
+
"bert": {
|
46 |
+
"config_names": {
|
47 |
+
"context_length": "max_position_embeddings",
|
48 |
+
"vocab_size": "vocab_size",
|
49 |
+
"width": "hidden_size",
|
50 |
+
"heads": "num_attention_heads",
|
51 |
+
"layers": "num_hidden_layers",
|
52 |
+
"layer_attr": "layer",
|
53 |
+
"token_embeddings_attr": "embeddings",
|
54 |
+
},
|
55 |
+
"pooler": "mean_pooler",
|
56 |
+
},
|
57 |
+
}
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_model.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" huggingface model adapter
|
2 |
+
|
3 |
+
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import re
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch.nn import functional as F
|
11 |
+
from torch import TensorType
|
12 |
+
|
13 |
+
try:
|
14 |
+
import transformers
|
15 |
+
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
|
16 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions
|
17 |
+
except ImportError as e:
|
18 |
+
transformers = None
|
19 |
+
|
20 |
+
class BaseModelOutput:
|
21 |
+
pass
|
22 |
+
|
23 |
+
class PretrainedConfig:
|
24 |
+
pass
|
25 |
+
|
26 |
+
|
27 |
+
from .hf_configs import arch_dict
|
28 |
+
|
29 |
+
|
30 |
+
# utils
|
31 |
+
def _camel2snake(s):
|
32 |
+
return re.sub(r"(?<!^)(?=[A-Z])", "_", s).lower()
|
33 |
+
|
34 |
+
|
35 |
+
# TODO: ?last - for gpt-like models
|
36 |
+
_POOLERS = {}
|
37 |
+
|
38 |
+
|
39 |
+
def register_pooler(cls):
|
40 |
+
"""Decorator registering pooler class"""
|
41 |
+
_POOLERS[_camel2snake(cls.__name__)] = cls
|
42 |
+
return cls
|
43 |
+
|
44 |
+
|
45 |
+
@register_pooler
|
46 |
+
class MeanPooler(nn.Module):
|
47 |
+
"""Mean pooling"""
|
48 |
+
|
49 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
50 |
+
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
|
51 |
+
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
|
52 |
+
|
53 |
+
|
54 |
+
@register_pooler
|
55 |
+
class MaxPooler(nn.Module):
|
56 |
+
"""Max pooling"""
|
57 |
+
|
58 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
59 |
+
masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
|
60 |
+
return masked_output.max(1).values
|
61 |
+
|
62 |
+
|
63 |
+
@register_pooler
|
64 |
+
class ClsPooler(nn.Module):
|
65 |
+
"""CLS token pooling"""
|
66 |
+
|
67 |
+
def __init__(self, use_pooler_output=True):
|
68 |
+
super().__init__()
|
69 |
+
self.cls_token_position = 0
|
70 |
+
self.use_pooler_output = use_pooler_output
|
71 |
+
|
72 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
73 |
+
|
74 |
+
if self.use_pooler_output and isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and (x.pooler_output is not None):
|
75 |
+
return x.pooler_output
|
76 |
+
|
77 |
+
return x.last_hidden_state[:, self.cls_token_position, :]
|
78 |
+
|
79 |
+
|
80 |
+
class HFTextEncoder(nn.Module):
|
81 |
+
"""HuggingFace model adapter"""
|
82 |
+
|
83 |
+
def __init__(self, model_name_or_path: str, output_dim: int, tokenizer_name: str = None, config: PretrainedConfig = None, pooler_type: str = None, proj: str = None, pretrained: bool = True, masked_language_modeling: bool = False):
|
84 |
+
super().__init__()
|
85 |
+
|
86 |
+
self.output_dim = output_dim
|
87 |
+
|
88 |
+
# TODO: find better way to get this information
|
89 |
+
uses_transformer_pooler = pooler_type == "cls_pooler"
|
90 |
+
|
91 |
+
if transformers is None:
|
92 |
+
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
|
93 |
+
if config is None:
|
94 |
+
self.config = AutoConfig.from_pretrained(model_name_or_path)
|
95 |
+
if masked_language_modeling:
|
96 |
+
create_func, model_args = (AutoModelForMaskedLM.from_pretrained, model_name_or_path) if pretrained else (AutoModelForMaskedLM.from_config, self.config)
|
97 |
+
else:
|
98 |
+
create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (AutoModel.from_config, self.config)
|
99 |
+
# TODO: do all model configs have this attribute? PretrainedConfig does so yes??
|
100 |
+
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
|
101 |
+
self.transformer = create_func(model_args)
|
102 |
+
self.transformer = self.transformer.encoder
|
103 |
+
else:
|
104 |
+
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
|
105 |
+
else:
|
106 |
+
self.config = config
|
107 |
+
if masked_language_modeling:
|
108 |
+
self.transformer = AutoModelForMaskedLM.from_config(config)
|
109 |
+
else:
|
110 |
+
self.transformer = AutoModel.from_config(config)
|
111 |
+
|
112 |
+
if pooler_type is None: # get default arch pooler
|
113 |
+
self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
|
114 |
+
else:
|
115 |
+
self.pooler = _POOLERS[pooler_type]()
|
116 |
+
|
117 |
+
d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
|
118 |
+
if (d_model == output_dim) and (proj is None): # do we always need a proj?
|
119 |
+
self.proj = nn.Identity()
|
120 |
+
elif proj == "linear":
|
121 |
+
self.proj = nn.Linear(d_model, output_dim, bias=False)
|
122 |
+
elif proj == "mlp":
|
123 |
+
hidden_size = (d_model + output_dim) // 2
|
124 |
+
self.proj = nn.Sequential(
|
125 |
+
nn.Linear(d_model, hidden_size, bias=False),
|
126 |
+
nn.GELU(),
|
127 |
+
nn.Linear(hidden_size, output_dim, bias=False),
|
128 |
+
)
|
129 |
+
|
130 |
+
# self.itm_proj = nn.Linear(d_model, 2, bias=False)
|
131 |
+
# self.mlm_proj = nn.Linear(d_model, self.config.vocab_size), bias=False)
|
132 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
133 |
+
|
134 |
+
# def forward_itm(self, x:TensorType, image_embeds:TensorType) -> TensorType:
|
135 |
+
# image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device)
|
136 |
+
# attn_mask = (x != self.config.pad_token_id).long()
|
137 |
+
# out = self.transformer(
|
138 |
+
# input_ids=x,
|
139 |
+
# attention_mask=attn_mask,
|
140 |
+
# encoder_hidden_states = image_embeds,
|
141 |
+
# encoder_attention_mask = image_atts,
|
142 |
+
# )
|
143 |
+
# pooled_out = self.pooler(out, attn_mask)
|
144 |
+
|
145 |
+
# return self.itm_proj(pooled_out)
|
146 |
+
|
147 |
+
def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
|
148 |
+
if masked_indices is None:
|
149 |
+
masked_indices = torch.bernoulli(probability_matrix).bool()
|
150 |
+
|
151 |
+
masked_indices[input_ids == self.tokenizer.pad_token_id] = False
|
152 |
+
masked_indices[input_ids == self.tokenizer.cls_token_id] = False
|
153 |
+
|
154 |
+
if targets is not None:
|
155 |
+
targets[~masked_indices] = -100 # We only compute loss on masked tokens
|
156 |
+
|
157 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
158 |
+
indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
|
159 |
+
input_ids[indices_replaced] = self.tokenizer.mask_token_id
|
160 |
+
|
161 |
+
# 10% of the time, we replace masked input tokens with random word
|
162 |
+
indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
163 |
+
random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
|
164 |
+
input_ids[indices_random] = random_words[indices_random]
|
165 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
166 |
+
|
167 |
+
if targets is not None:
|
168 |
+
return input_ids, targets
|
169 |
+
else:
|
170 |
+
return input_ids
|
171 |
+
|
172 |
+
def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
|
173 |
+
labels = input_ids.clone()
|
174 |
+
attn_mask = (input_ids != self.config.pad_token_id).long()
|
175 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(input_ids.device)
|
176 |
+
vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
|
177 |
+
probability_matrix = torch.full(labels.shape, mlm_probability)
|
178 |
+
input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels, probability_matrix=probability_matrix)
|
179 |
+
mlm_output = self.transformer(
|
180 |
+
input_ids,
|
181 |
+
attention_mask=attn_mask,
|
182 |
+
encoder_hidden_states=image_embeds,
|
183 |
+
encoder_attention_mask=image_atts,
|
184 |
+
return_dict=True,
|
185 |
+
labels=labels,
|
186 |
+
)
|
187 |
+
return mlm_output.loss
|
188 |
+
# mlm_output = self.transformer(input_ids,
|
189 |
+
# attention_mask = attn_mask,
|
190 |
+
# encoder_hidden_states = image_embeds,
|
191 |
+
# encoder_attention_mask = image_atts,
|
192 |
+
# return_dict = True,
|
193 |
+
# ).last_hidden_state
|
194 |
+
# logits = self.mlm_proj(mlm_output)
|
195 |
+
|
196 |
+
# # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
|
197 |
+
# logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)
|
198 |
+
# labels = labels[:, 1:].contiguous().view(-1)
|
199 |
+
|
200 |
+
# mlm_loss = F.cross_entropy(
|
201 |
+
# logits,
|
202 |
+
# labels,
|
203 |
+
# # label_smoothing=0.1,
|
204 |
+
# )
|
205 |
+
# return mlm_loss
|
206 |
+
|
207 |
+
def forward(self, x: TensorType) -> TensorType:
|
208 |
+
attn_mask = (x != self.config.pad_token_id).long()
|
209 |
+
out = self.transformer(input_ids=x, attention_mask=attn_mask)
|
210 |
+
pooled_out = self.pooler(out, attn_mask)
|
211 |
+
|
212 |
+
return self.proj(pooled_out)
|
213 |
+
|
214 |
+
def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
215 |
+
if not unlocked_layers: # full freezing
|
216 |
+
for n, p in self.transformer.named_parameters():
|
217 |
+
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
218 |
+
return
|
219 |
+
|
220 |
+
encoder = self.transformer.encoder if hasattr(self.transformer, "encoder") else self.transformer
|
221 |
+
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
|
222 |
+
print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
|
223 |
+
embeddings = getattr(self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
|
224 |
+
modules = [embeddings, *layer_list][:-unlocked_layers]
|
225 |
+
# freeze layers
|
226 |
+
for module in modules:
|
227 |
+
for n, p in module.named_parameters():
|
228 |
+
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
229 |
+
|
230 |
+
@torch.jit.ignore
|
231 |
+
def set_grad_checkpointing(self, enable=True):
|
232 |
+
self.transformer.gradient_checkpointing_enable()
|
233 |
+
|
234 |
+
def get_num_layers(self):
|
235 |
+
encoder = self.transformer.encoder if hasattr(self.transformer, "encoder") else self.transformer
|
236 |
+
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
|
237 |
+
return len(layer_list)
|
238 |
+
|
239 |
+
def init_parameters(self):
|
240 |
+
pass
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
try:
|
7 |
+
import torch.distributed.nn
|
8 |
+
from torch import distributed as dist
|
9 |
+
|
10 |
+
has_distributed = True
|
11 |
+
except ImportError:
|
12 |
+
has_distributed = False
|
13 |
+
|
14 |
+
try:
|
15 |
+
import horovod.torch as hvd
|
16 |
+
except ImportError:
|
17 |
+
hvd = None
|
18 |
+
|
19 |
+
from timm.loss import LabelSmoothingCrossEntropy
|
20 |
+
|
21 |
+
|
22 |
+
def gather_features(image_features, text_features, local_loss=False, gather_with_grad=False, rank=0, world_size=1, use_horovod=False):
|
23 |
+
assert has_distributed, "torch.distributed did not import correctly, please use a PyTorch version with support."
|
24 |
+
if use_horovod:
|
25 |
+
assert hvd is not None, "Please install horovod"
|
26 |
+
if gather_with_grad:
|
27 |
+
all_image_features = hvd.allgather(image_features)
|
28 |
+
all_text_features = hvd.allgather(text_features)
|
29 |
+
else:
|
30 |
+
with torch.no_grad():
|
31 |
+
all_image_features = hvd.allgather(image_features)
|
32 |
+
all_text_features = hvd.allgather(text_features)
|
33 |
+
if not local_loss:
|
34 |
+
# ensure grads for local rank when all_* features don't have a gradient
|
35 |
+
gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
|
36 |
+
gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
|
37 |
+
gathered_image_features[rank] = image_features
|
38 |
+
gathered_text_features[rank] = text_features
|
39 |
+
all_image_features = torch.cat(gathered_image_features, dim=0)
|
40 |
+
all_text_features = torch.cat(gathered_text_features, dim=0)
|
41 |
+
else:
|
42 |
+
# We gather tensors from all gpus
|
43 |
+
if gather_with_grad:
|
44 |
+
all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
|
45 |
+
all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
|
46 |
+
# all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)
|
47 |
+
# all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)
|
48 |
+
else:
|
49 |
+
gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
|
50 |
+
gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
|
51 |
+
dist.all_gather(gathered_image_features, image_features)
|
52 |
+
dist.all_gather(gathered_text_features, text_features)
|
53 |
+
if not local_loss:
|
54 |
+
# ensure grads for local rank when all_* features don't have a gradient
|
55 |
+
gathered_image_features[rank] = image_features
|
56 |
+
gathered_text_features[rank] = text_features
|
57 |
+
all_image_features = torch.cat(gathered_image_features, dim=0)
|
58 |
+
all_text_features = torch.cat(gathered_text_features, dim=0)
|
59 |
+
|
60 |
+
return all_image_features, all_text_features
|
61 |
+
|
62 |
+
|
63 |
+
class ClipLoss(nn.Module):
|
64 |
+
|
65 |
+
def __init__(
|
66 |
+
self,
|
67 |
+
local_loss=False,
|
68 |
+
gather_with_grad=False,
|
69 |
+
cache_labels=False,
|
70 |
+
rank=0,
|
71 |
+
world_size=1,
|
72 |
+
use_horovod=False,
|
73 |
+
smoothing=0.0,
|
74 |
+
):
|
75 |
+
super().__init__()
|
76 |
+
self.local_loss = local_loss
|
77 |
+
self.gather_with_grad = gather_with_grad
|
78 |
+
self.cache_labels = cache_labels
|
79 |
+
self.rank = rank
|
80 |
+
self.world_size = world_size
|
81 |
+
self.use_horovod = use_horovod
|
82 |
+
self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None
|
83 |
+
|
84 |
+
# cache state
|
85 |
+
self.prev_num_logits = 0
|
86 |
+
self.labels = {}
|
87 |
+
|
88 |
+
def forward(self, image_features, text_features, logit_scale=1.0):
|
89 |
+
device = image_features.device
|
90 |
+
if self.world_size > 1:
|
91 |
+
all_image_features, all_text_features = gather_features(image_features, text_features, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
|
92 |
+
|
93 |
+
if self.local_loss:
|
94 |
+
logits_per_image = logit_scale * image_features @ all_text_features.T
|
95 |
+
logits_per_text = logit_scale * text_features @ all_image_features.T
|
96 |
+
else:
|
97 |
+
logits_per_image = logit_scale * all_image_features @ all_text_features.T
|
98 |
+
logits_per_text = logits_per_image.T
|
99 |
+
else:
|
100 |
+
logits_per_image = logit_scale * image_features @ text_features.T
|
101 |
+
logits_per_text = logit_scale * text_features @ image_features.T
|
102 |
+
# calculated ground-truth and cache if enabled
|
103 |
+
num_logits = logits_per_image.shape[0]
|
104 |
+
if self.prev_num_logits != num_logits or device not in self.labels:
|
105 |
+
labels = torch.arange(num_logits, device=device, dtype=torch.long)
|
106 |
+
if self.world_size > 1 and self.local_loss:
|
107 |
+
labels = labels + num_logits * self.rank
|
108 |
+
if self.cache_labels:
|
109 |
+
self.labels[device] = labels
|
110 |
+
self.prev_num_logits = num_logits
|
111 |
+
else:
|
112 |
+
labels = self.labels[device]
|
113 |
+
|
114 |
+
if self.label_smoothing_cross_entropy:
|
115 |
+
total_loss = (self.label_smoothing_cross_entropy(logits_per_image, labels) + self.label_smoothing_cross_entropy(logits_per_text, labels)) / 2
|
116 |
+
else:
|
117 |
+
total_loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2
|
118 |
+
|
119 |
+
acc = None
|
120 |
+
i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
|
121 |
+
t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
|
122 |
+
acc = {"i2t": i2t_acc, "t2i": t2i_acc}
|
123 |
+
return total_loss, acc
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/model.py
ADDED
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" CLIP Model
|
2 |
+
|
3 |
+
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from typing import Optional, Tuple, Union
|
9 |
+
from functools import partial
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from torch import nn
|
15 |
+
|
16 |
+
try:
|
17 |
+
from .hf_model import HFTextEncoder
|
18 |
+
except:
|
19 |
+
HFTextEncoder = None
|
20 |
+
from .modified_resnet import ModifiedResNet
|
21 |
+
from .timm_model import TimmModel
|
22 |
+
from .eva_vit_model import EVAVisionTransformer
|
23 |
+
from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
|
24 |
+
|
25 |
+
try:
|
26 |
+
from apex.normalization import FusedLayerNorm
|
27 |
+
except:
|
28 |
+
FusedLayerNorm = LayerNorm
|
29 |
+
# print("Please 'pip install apex'")
|
30 |
+
|
31 |
+
try:
|
32 |
+
import xformers.ops as xops
|
33 |
+
except ImportError:
|
34 |
+
xops = None
|
35 |
+
# print("Please 'pip install xformers'")
|
36 |
+
|
37 |
+
|
38 |
+
class RMSnorm(nn.Module):
|
39 |
+
"""
|
40 |
+
adepted from transformers T5LayerNorm
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(self, hidden_size, eps=1e-6):
|
44 |
+
"""
|
45 |
+
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
|
46 |
+
"""
|
47 |
+
super().__init__()
|
48 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
49 |
+
self.variance_epsilon = eps
|
50 |
+
|
51 |
+
def forward(self, hidden_states):
|
52 |
+
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
53 |
+
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
|
54 |
+
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
55 |
+
# half-precision inputs is done in fp32
|
56 |
+
|
57 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
58 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
59 |
+
|
60 |
+
# convert into half-precision if necessary
|
61 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
62 |
+
hidden_states = hidden_states.to(self.weight.dtype)
|
63 |
+
|
64 |
+
return self.weight * hidden_states
|
65 |
+
|
66 |
+
|
67 |
+
@dataclass
|
68 |
+
class CLIPVisionCfg:
|
69 |
+
layers: Union[Tuple[int, int, int, int], int] = 12
|
70 |
+
width: int = 768
|
71 |
+
head_width: int = 64
|
72 |
+
mlp_ratio: float = 4.0
|
73 |
+
patch_size: int = 16
|
74 |
+
image_size: Union[Tuple[int, int], int] = 224
|
75 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
76 |
+
patch_dropout: float = 0.0 # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
|
77 |
+
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
|
78 |
+
drop_path_rate: Optional[float] = None # drop path rate
|
79 |
+
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
|
80 |
+
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
|
81 |
+
timm_pool: str = "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
|
82 |
+
timm_proj: str = "linear" # linear projection for timm model output ('linear', 'mlp', '')
|
83 |
+
timm_proj_bias: bool = False # enable bias final projection
|
84 |
+
eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
|
85 |
+
qkv_bias: bool = True
|
86 |
+
fusedLN: bool = False
|
87 |
+
xattn: bool = False
|
88 |
+
postnorm: bool = False
|
89 |
+
rope: bool = False
|
90 |
+
pt_hw_seq_len: int = 16 # 224/14
|
91 |
+
intp_freq: bool = False
|
92 |
+
naiveswiglu: bool = False
|
93 |
+
subln: bool = False
|
94 |
+
use_rms_norm: bool = False
|
95 |
+
|
96 |
+
|
97 |
+
@dataclass
|
98 |
+
class CLIPTextCfg:
|
99 |
+
context_length: int = 77
|
100 |
+
vocab_size: int = 49408
|
101 |
+
width: int = 512
|
102 |
+
heads: int = 8
|
103 |
+
layers: int = 12
|
104 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
105 |
+
hf_model_name: str = None
|
106 |
+
hf_tokenizer_name: str = None
|
107 |
+
hf_model_pretrained: bool = True
|
108 |
+
proj: str = "mlp"
|
109 |
+
pooler_type: str = "mean_pooler"
|
110 |
+
masked_language_modeling: bool = False
|
111 |
+
fusedLN: bool = False
|
112 |
+
xattn: bool = False
|
113 |
+
attn_mask: bool = True
|
114 |
+
|
115 |
+
|
116 |
+
def get_cast_dtype(precision: str):
|
117 |
+
cast_dtype = None
|
118 |
+
if precision == "bf16":
|
119 |
+
cast_dtype = torch.bfloat16
|
120 |
+
elif precision == "fp16":
|
121 |
+
cast_dtype = torch.float16
|
122 |
+
return cast_dtype
|
123 |
+
|
124 |
+
|
125 |
+
def _build_vision_tower(embed_dim: int, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None):
|
126 |
+
if isinstance(vision_cfg, dict):
|
127 |
+
vision_cfg = CLIPVisionCfg(**vision_cfg)
|
128 |
+
|
129 |
+
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
|
130 |
+
# memory efficient in recent PyTorch releases (>= 1.10).
|
131 |
+
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
|
132 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
133 |
+
|
134 |
+
if vision_cfg.eva_model_name:
|
135 |
+
vision_heads = vision_cfg.width // vision_cfg.head_width
|
136 |
+
|
137 |
+
norm_layer = RMSnorm if vision_cfg.use_rms_norm else LayerNorm
|
138 |
+
|
139 |
+
visual = EVAVisionTransformer(
|
140 |
+
img_size=vision_cfg.image_size,
|
141 |
+
patch_size=vision_cfg.patch_size,
|
142 |
+
num_classes=embed_dim,
|
143 |
+
use_mean_pooling=vision_cfg.global_average_pool, # False
|
144 |
+
init_values=vision_cfg.ls_init_value,
|
145 |
+
patch_dropout=vision_cfg.patch_dropout,
|
146 |
+
embed_dim=vision_cfg.width,
|
147 |
+
depth=vision_cfg.layers,
|
148 |
+
num_heads=vision_heads,
|
149 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
150 |
+
qkv_bias=vision_cfg.qkv_bias,
|
151 |
+
drop_path_rate=vision_cfg.drop_path_rate,
|
152 |
+
norm_layer=partial(norm_layer, eps=1e-6),
|
153 |
+
xattn=vision_cfg.xattn,
|
154 |
+
rope=vision_cfg.rope,
|
155 |
+
postnorm=vision_cfg.postnorm,
|
156 |
+
pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14
|
157 |
+
intp_freq=vision_cfg.intp_freq,
|
158 |
+
naiveswiglu=vision_cfg.naiveswiglu,
|
159 |
+
subln=vision_cfg.subln,
|
160 |
+
)
|
161 |
+
elif vision_cfg.timm_model_name:
|
162 |
+
visual = TimmModel(
|
163 |
+
vision_cfg.timm_model_name, pretrained=vision_cfg.timm_model_pretrained, pool=vision_cfg.timm_pool, proj=vision_cfg.timm_proj, proj_bias=vision_cfg.timm_proj_bias, embed_dim=embed_dim, image_size=vision_cfg.image_size
|
164 |
+
)
|
165 |
+
act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
|
166 |
+
elif isinstance(vision_cfg.layers, (tuple, list)):
|
167 |
+
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
|
168 |
+
visual = ModifiedResNet(layers=vision_cfg.layers, output_dim=embed_dim, heads=vision_heads, image_size=vision_cfg.image_size, width=vision_cfg.width)
|
169 |
+
else:
|
170 |
+
vision_heads = vision_cfg.width // vision_cfg.head_width
|
171 |
+
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
172 |
+
visual = VisionTransformer(
|
173 |
+
image_size=vision_cfg.image_size,
|
174 |
+
patch_size=vision_cfg.patch_size,
|
175 |
+
width=vision_cfg.width,
|
176 |
+
layers=vision_cfg.layers,
|
177 |
+
heads=vision_heads,
|
178 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
179 |
+
ls_init_value=vision_cfg.ls_init_value,
|
180 |
+
patch_dropout=vision_cfg.patch_dropout,
|
181 |
+
global_average_pool=vision_cfg.global_average_pool,
|
182 |
+
output_dim=embed_dim,
|
183 |
+
act_layer=act_layer,
|
184 |
+
norm_layer=norm_layer,
|
185 |
+
)
|
186 |
+
|
187 |
+
return visual
|
188 |
+
|
189 |
+
|
190 |
+
def _build_text_tower(
|
191 |
+
embed_dim: int,
|
192 |
+
text_cfg: CLIPTextCfg,
|
193 |
+
quick_gelu: bool = False,
|
194 |
+
cast_dtype: Optional[torch.dtype] = None,
|
195 |
+
):
|
196 |
+
if isinstance(text_cfg, dict):
|
197 |
+
text_cfg = CLIPTextCfg(**text_cfg)
|
198 |
+
|
199 |
+
if text_cfg.hf_model_name:
|
200 |
+
text = HFTextEncoder(text_cfg.hf_model_name, output_dim=embed_dim, tokenizer_name=text_cfg.hf_tokenizer_name, proj=text_cfg.proj, pooler_type=text_cfg.pooler_type, masked_language_modeling=text_cfg.masked_language_modeling)
|
201 |
+
else:
|
202 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
203 |
+
norm_layer = LayerNorm
|
204 |
+
|
205 |
+
text = TextTransformer(
|
206 |
+
context_length=text_cfg.context_length,
|
207 |
+
vocab_size=text_cfg.vocab_size,
|
208 |
+
width=text_cfg.width,
|
209 |
+
heads=text_cfg.heads,
|
210 |
+
layers=text_cfg.layers,
|
211 |
+
ls_init_value=text_cfg.ls_init_value,
|
212 |
+
output_dim=embed_dim,
|
213 |
+
act_layer=act_layer,
|
214 |
+
norm_layer=FusedLayerNorm if text_cfg.fusedLN else norm_layer,
|
215 |
+
xattn=text_cfg.xattn,
|
216 |
+
attn_mask=text_cfg.attn_mask,
|
217 |
+
)
|
218 |
+
return text
|
219 |
+
|
220 |
+
|
221 |
+
class CLIP(nn.Module):
|
222 |
+
def __init__(
|
223 |
+
self,
|
224 |
+
embed_dim: int,
|
225 |
+
vision_cfg: CLIPVisionCfg,
|
226 |
+
text_cfg: CLIPTextCfg,
|
227 |
+
quick_gelu: bool = False,
|
228 |
+
cast_dtype: Optional[torch.dtype] = None,
|
229 |
+
):
|
230 |
+
super().__init__()
|
231 |
+
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
232 |
+
|
233 |
+
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
234 |
+
self.transformer = text.transformer
|
235 |
+
self.vocab_size = text.vocab_size
|
236 |
+
self.token_embedding = text.token_embedding
|
237 |
+
self.positional_embedding = text.positional_embedding
|
238 |
+
self.ln_final = text.ln_final
|
239 |
+
self.text_projection = text.text_projection
|
240 |
+
self.register_buffer("attn_mask", text.attn_mask, persistent=False)
|
241 |
+
|
242 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
243 |
+
|
244 |
+
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
245 |
+
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
246 |
+
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
247 |
+
|
248 |
+
@torch.jit.ignore
|
249 |
+
def set_grad_checkpointing(self, enable=True):
|
250 |
+
self.visual.set_grad_checkpointing(enable)
|
251 |
+
self.transformer.grad_checkpointing = enable
|
252 |
+
|
253 |
+
@torch.jit.ignore
|
254 |
+
def no_weight_decay(self):
|
255 |
+
return {"logit_scale"}
|
256 |
+
|
257 |
+
def encode_image(self, image, normalize: bool = False):
|
258 |
+
features = self.visual(image)
|
259 |
+
return F.normalize(features, dim=-1) if normalize else features
|
260 |
+
|
261 |
+
def encode_text(self, text, normalize: bool = False):
|
262 |
+
cast_dtype = self.transformer.get_cast_dtype()
|
263 |
+
|
264 |
+
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
265 |
+
|
266 |
+
x = x + self.positional_embedding.to(cast_dtype)
|
267 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
268 |
+
x = self.transformer(x, attn_mask=self.attn_mask)
|
269 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
270 |
+
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
|
271 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
272 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
273 |
+
return F.normalize(x, dim=-1) if normalize else x
|
274 |
+
|
275 |
+
def forward(self, image, text):
|
276 |
+
image_features = self.encode_image(image, normalize=True)
|
277 |
+
text_features = self.encode_text(text, normalize=True)
|
278 |
+
return image_features, text_features, self.logit_scale.exp()
|
279 |
+
|
280 |
+
|
281 |
+
class CustomCLIP(nn.Module):
|
282 |
+
def __init__(
|
283 |
+
self,
|
284 |
+
embed_dim: int,
|
285 |
+
vision_cfg: CLIPVisionCfg,
|
286 |
+
text_cfg: CLIPTextCfg,
|
287 |
+
quick_gelu: bool = False,
|
288 |
+
cast_dtype: Optional[torch.dtype] = None,
|
289 |
+
itm_task: bool = False,
|
290 |
+
):
|
291 |
+
super().__init__()
|
292 |
+
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
293 |
+
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
294 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
295 |
+
|
296 |
+
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
297 |
+
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
298 |
+
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
299 |
+
|
300 |
+
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
301 |
+
self.text.lock(unlocked_layers, freeze_layer_norm)
|
302 |
+
|
303 |
+
@torch.jit.ignore
|
304 |
+
def set_grad_checkpointing(self, enable=True):
|
305 |
+
self.visual.set_grad_checkpointing(enable)
|
306 |
+
self.text.set_grad_checkpointing(enable)
|
307 |
+
|
308 |
+
@torch.jit.ignore
|
309 |
+
def no_weight_decay(self):
|
310 |
+
return {"logit_scale"}
|
311 |
+
|
312 |
+
def encode_image(self, image, normalize: bool = False):
|
313 |
+
features = self.visual(image)
|
314 |
+
return F.normalize(features, dim=-1) if normalize else features
|
315 |
+
|
316 |
+
def encode_text(self, text, normalize: bool = False):
|
317 |
+
features = self.text(text)
|
318 |
+
return F.normalize(features, dim=-1) if normalize else features
|
319 |
+
|
320 |
+
def forward(self, image, text):
|
321 |
+
image_features = self.encode_image(image, normalize=True)
|
322 |
+
text_features = self.encode_text(text, normalize=True)
|
323 |
+
return image_features, text_features, self.logit_scale.exp()
|
324 |
+
|
325 |
+
|
326 |
+
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
|
327 |
+
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""
|
328 |
+
|
329 |
+
def _convert_weights(l):
|
330 |
+
|
331 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
332 |
+
l.weight.data = l.weight.data.to(dtype)
|
333 |
+
if l.bias is not None:
|
334 |
+
l.bias.data = l.bias.data.to(dtype)
|
335 |
+
|
336 |
+
if isinstance(l, (nn.MultiheadAttention, Attention)):
|
337 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
338 |
+
tensor = getattr(l, attr, None)
|
339 |
+
if tensor is not None:
|
340 |
+
tensor.data = tensor.data.to(dtype)
|
341 |
+
|
342 |
+
if isinstance(l, nn.Parameter):
|
343 |
+
l.data = l.data.to(dtype)
|
344 |
+
|
345 |
+
for name in ["text_projection", "proj"]:
|
346 |
+
if hasattr(l, name) and isinstance(l, nn.Parameter):
|
347 |
+
attr = getattr(l, name, None)
|
348 |
+
if attr is not None:
|
349 |
+
attr.data = attr.data.to(dtype)
|
350 |
+
|
351 |
+
model.apply(_convert_weights)
|
352 |
+
|
353 |
+
|
354 |
+
convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
|
355 |
+
|
356 |
+
|
357 |
+
# used to maintain checkpoint compatibility
|
358 |
+
def convert_to_custom_text_state_dict(state_dict: dict):
|
359 |
+
if "text_projection" in state_dict:
|
360 |
+
# old format state_dict, move text tower -> .text
|
361 |
+
new_state_dict = {}
|
362 |
+
for k, v in state_dict.items():
|
363 |
+
if any(k.startswith(p) for p in ("text_projection", "positional_embedding", "token_embedding", "transformer", "ln_final", "logit_scale")):
|
364 |
+
k = "text." + k
|
365 |
+
new_state_dict[k] = v
|
366 |
+
return new_state_dict
|
367 |
+
return state_dict
|
368 |
+
|
369 |
+
|
370 |
+
def build_model_from_openai_state_dict(
|
371 |
+
state_dict: dict,
|
372 |
+
quick_gelu=True,
|
373 |
+
cast_dtype=torch.float16,
|
374 |
+
):
|
375 |
+
vit = "visual.proj" in state_dict
|
376 |
+
|
377 |
+
if vit:
|
378 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
379 |
+
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
380 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
381 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
382 |
+
image_size = vision_patch_size * grid_size
|
383 |
+
else:
|
384 |
+
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
385 |
+
vision_layers = tuple(counts)
|
386 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
387 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
388 |
+
vision_patch_size = None
|
389 |
+
assert output_width**2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
390 |
+
image_size = output_width * 32
|
391 |
+
|
392 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
393 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
394 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
395 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
396 |
+
transformer_heads = transformer_width // 64
|
397 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
398 |
+
|
399 |
+
vision_cfg = CLIPVisionCfg(
|
400 |
+
layers=vision_layers,
|
401 |
+
width=vision_width,
|
402 |
+
patch_size=vision_patch_size,
|
403 |
+
image_size=image_size,
|
404 |
+
)
|
405 |
+
text_cfg = CLIPTextCfg(context_length=context_length, vocab_size=vocab_size, width=transformer_width, heads=transformer_heads, layers=transformer_layers)
|
406 |
+
model = CLIP(
|
407 |
+
embed_dim,
|
408 |
+
vision_cfg=vision_cfg,
|
409 |
+
text_cfg=text_cfg,
|
410 |
+
quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
|
411 |
+
cast_dtype=cast_dtype,
|
412 |
+
)
|
413 |
+
|
414 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
415 |
+
state_dict.pop(key, None)
|
416 |
+
|
417 |
+
convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
|
418 |
+
model.load_state_dict(state_dict)
|
419 |
+
return model.eval()
|
420 |
+
|
421 |
+
|
422 |
+
def trace_model(model, batch_size=256, device=torch.device("cpu")):
|
423 |
+
model.eval()
|
424 |
+
image_size = model.visual.image_size
|
425 |
+
example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
|
426 |
+
example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
|
427 |
+
model = torch.jit.trace_module(model, inputs=dict(forward=(example_images, example_text), encode_text=(example_text,), encode_image=(example_images,)))
|
428 |
+
model.visual.image_size = image_size
|
429 |
+
return model
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
from .utils import freeze_batch_norm_2d
|
8 |
+
|
9 |
+
|
10 |
+
class Bottleneck(nn.Module):
|
11 |
+
expansion = 4
|
12 |
+
|
13 |
+
def __init__(self, inplanes, planes, stride=1):
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
17 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
18 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
19 |
+
self.act1 = nn.ReLU(inplace=True)
|
20 |
+
|
21 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
22 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
23 |
+
self.act2 = nn.ReLU(inplace=True)
|
24 |
+
|
25 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
26 |
+
|
27 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
28 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
29 |
+
self.act3 = nn.ReLU(inplace=True)
|
30 |
+
|
31 |
+
self.downsample = None
|
32 |
+
self.stride = stride
|
33 |
+
|
34 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
35 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
36 |
+
self.downsample = nn.Sequential(OrderedDict([("-1", nn.AvgPool2d(stride)), ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), ("1", nn.BatchNorm2d(planes * self.expansion))]))
|
37 |
+
|
38 |
+
def forward(self, x: torch.Tensor):
|
39 |
+
identity = x
|
40 |
+
|
41 |
+
out = self.act1(self.bn1(self.conv1(x)))
|
42 |
+
out = self.act2(self.bn2(self.conv2(out)))
|
43 |
+
out = self.avgpool(out)
|
44 |
+
out = self.bn3(self.conv3(out))
|
45 |
+
|
46 |
+
if self.downsample is not None:
|
47 |
+
identity = self.downsample(x)
|
48 |
+
|
49 |
+
out += identity
|
50 |
+
out = self.act3(out)
|
51 |
+
return out
|
52 |
+
|
53 |
+
|
54 |
+
class AttentionPool2d(nn.Module):
|
55 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
56 |
+
super().__init__()
|
57 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
|
58 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
59 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
60 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
61 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
62 |
+
self.num_heads = num_heads
|
63 |
+
|
64 |
+
def forward(self, x):
|
65 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
66 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
67 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
68 |
+
x, _ = F.multi_head_attention_forward(
|
69 |
+
query=x,
|
70 |
+
key=x,
|
71 |
+
value=x,
|
72 |
+
embed_dim_to_check=x.shape[-1],
|
73 |
+
num_heads=self.num_heads,
|
74 |
+
q_proj_weight=self.q_proj.weight,
|
75 |
+
k_proj_weight=self.k_proj.weight,
|
76 |
+
v_proj_weight=self.v_proj.weight,
|
77 |
+
in_proj_weight=None,
|
78 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
79 |
+
bias_k=None,
|
80 |
+
bias_v=None,
|
81 |
+
add_zero_attn=False,
|
82 |
+
dropout_p=0.0,
|
83 |
+
out_proj_weight=self.c_proj.weight,
|
84 |
+
out_proj_bias=self.c_proj.bias,
|
85 |
+
use_separate_proj_weight=True,
|
86 |
+
training=self.training,
|
87 |
+
need_weights=False,
|
88 |
+
)
|
89 |
+
|
90 |
+
return x[0]
|
91 |
+
|
92 |
+
|
93 |
+
class ModifiedResNet(nn.Module):
|
94 |
+
"""
|
95 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
96 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
97 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
98 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
99 |
+
"""
|
100 |
+
|
101 |
+
def __init__(self, layers, output_dim, heads, image_size=224, width=64):
|
102 |
+
super().__init__()
|
103 |
+
self.output_dim = output_dim
|
104 |
+
self.image_size = image_size
|
105 |
+
|
106 |
+
# the 3-layer stem
|
107 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
108 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
109 |
+
self.act1 = nn.ReLU(inplace=True)
|
110 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
111 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
112 |
+
self.act2 = nn.ReLU(inplace=True)
|
113 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
114 |
+
self.bn3 = nn.BatchNorm2d(width)
|
115 |
+
self.act3 = nn.ReLU(inplace=True)
|
116 |
+
self.avgpool = nn.AvgPool2d(2)
|
117 |
+
|
118 |
+
# residual layers
|
119 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
120 |
+
self.layer1 = self._make_layer(width, layers[0])
|
121 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
122 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
123 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
124 |
+
|
125 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
126 |
+
self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
|
127 |
+
|
128 |
+
self.init_parameters()
|
129 |
+
|
130 |
+
def _make_layer(self, planes, blocks, stride=1):
|
131 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
132 |
+
|
133 |
+
self._inplanes = planes * Bottleneck.expansion
|
134 |
+
for _ in range(1, blocks):
|
135 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
136 |
+
|
137 |
+
return nn.Sequential(*layers)
|
138 |
+
|
139 |
+
def init_parameters(self):
|
140 |
+
if self.attnpool is not None:
|
141 |
+
std = self.attnpool.c_proj.in_features**-0.5
|
142 |
+
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
|
143 |
+
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
|
144 |
+
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
|
145 |
+
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
|
146 |
+
|
147 |
+
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
|
148 |
+
for name, param in resnet_block.named_parameters():
|
149 |
+
if name.endswith("bn3.weight"):
|
150 |
+
nn.init.zeros_(param)
|
151 |
+
|
152 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
153 |
+
assert unlocked_groups == 0, "partial locking not currently supported for this model"
|
154 |
+
for param in self.parameters():
|
155 |
+
param.requires_grad = False
|
156 |
+
if freeze_bn_stats:
|
157 |
+
freeze_batch_norm_2d(self)
|
158 |
+
|
159 |
+
@torch.jit.ignore
|
160 |
+
def set_grad_checkpointing(self, enable=True):
|
161 |
+
# FIXME support for non-transformer
|
162 |
+
pass
|
163 |
+
|
164 |
+
def stem(self, x):
|
165 |
+
x = self.act1(self.bn1(self.conv1(x)))
|
166 |
+
x = self.act2(self.bn2(self.conv2(x)))
|
167 |
+
x = self.act3(self.bn3(self.conv3(x)))
|
168 |
+
x = self.avgpool(x)
|
169 |
+
return x
|
170 |
+
|
171 |
+
def forward(self, x):
|
172 |
+
x = self.stem(x)
|
173 |
+
x = self.layer1(x)
|
174 |
+
x = self.layer2(x)
|
175 |
+
x = self.layer3(x)
|
176 |
+
x = self.layer4(x)
|
177 |
+
x = self.attnpool(x)
|
178 |
+
|
179 |
+
return x
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" OpenAI pretrained model functions
|
2 |
+
|
3 |
+
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import warnings
|
8 |
+
from typing import List, Optional, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
|
13 |
+
from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
|
14 |
+
|
15 |
+
__all__ = ["list_openai_models", "load_openai_model"]
|
16 |
+
|
17 |
+
|
18 |
+
def list_openai_models() -> List[str]:
|
19 |
+
"""Returns the names of available CLIP models"""
|
20 |
+
return list_pretrained_models_by_tag("openai")
|
21 |
+
|
22 |
+
|
23 |
+
def load_openai_model(
|
24 |
+
name: str,
|
25 |
+
precision: Optional[str] = None,
|
26 |
+
device: Optional[Union[str, torch.device]] = None,
|
27 |
+
jit: bool = True,
|
28 |
+
cache_dir: Optional[str] = None,
|
29 |
+
):
|
30 |
+
"""Load a CLIP model
|
31 |
+
|
32 |
+
Parameters
|
33 |
+
----------
|
34 |
+
name : str
|
35 |
+
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
36 |
+
precision: str
|
37 |
+
Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
|
38 |
+
device : Union[str, torch.device]
|
39 |
+
The device to put the loaded model
|
40 |
+
jit : bool
|
41 |
+
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
|
42 |
+
cache_dir : Optional[str]
|
43 |
+
The directory to cache the downloaded model weights
|
44 |
+
|
45 |
+
Returns
|
46 |
+
-------
|
47 |
+
model : torch.nn.Module
|
48 |
+
The CLIP model
|
49 |
+
preprocess : Callable[[PIL.Image], torch.Tensor]
|
50 |
+
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
51 |
+
"""
|
52 |
+
if device is None:
|
53 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
54 |
+
if precision is None:
|
55 |
+
precision = "fp32" if device == "cpu" else "fp16"
|
56 |
+
|
57 |
+
if get_pretrained_url(name, "openai"):
|
58 |
+
model_path = download_pretrained_from_url(get_pretrained_url(name, "openai"), cache_dir=cache_dir)
|
59 |
+
elif os.path.isfile(name):
|
60 |
+
model_path = name
|
61 |
+
else:
|
62 |
+
raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
|
63 |
+
|
64 |
+
try:
|
65 |
+
# loading JIT archive
|
66 |
+
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
67 |
+
state_dict = None
|
68 |
+
except RuntimeError:
|
69 |
+
# loading saved state dict
|
70 |
+
if jit:
|
71 |
+
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
72 |
+
jit = False
|
73 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
74 |
+
|
75 |
+
if not jit:
|
76 |
+
# Build a non-jit model from the OpenAI jitted model state dict
|
77 |
+
cast_dtype = get_cast_dtype(precision)
|
78 |
+
try:
|
79 |
+
model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
|
80 |
+
except KeyError:
|
81 |
+
sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
|
82 |
+
model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
|
83 |
+
|
84 |
+
# model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
|
85 |
+
model = model.to(device)
|
86 |
+
if precision.startswith("amp") or precision == "fp32":
|
87 |
+
model.float()
|
88 |
+
elif precision == "bf16":
|
89 |
+
convert_weights_to_lp(model, dtype=torch.bfloat16)
|
90 |
+
|
91 |
+
return model
|
92 |
+
|
93 |
+
# patch the device names
|
94 |
+
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
95 |
+
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
96 |
+
|
97 |
+
def patch_device(module):
|
98 |
+
try:
|
99 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
100 |
+
except RuntimeError:
|
101 |
+
graphs = []
|
102 |
+
|
103 |
+
if hasattr(module, "forward1"):
|
104 |
+
graphs.append(module.forward1.graph)
|
105 |
+
|
106 |
+
for graph in graphs:
|
107 |
+
for node in graph.findAllNodes("prim::Constant"):
|
108 |
+
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
|
109 |
+
node.copyAttributes(device_node)
|
110 |
+
|
111 |
+
model.apply(patch_device)
|
112 |
+
patch_device(model.encode_image)
|
113 |
+
patch_device(model.encode_text)
|
114 |
+
|
115 |
+
# patch dtype to float32 (typically for CPU)
|
116 |
+
if precision == "fp32":
|
117 |
+
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
118 |
+
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
119 |
+
float_node = float_input.node()
|
120 |
+
|
121 |
+
def patch_float(module):
|
122 |
+
try:
|
123 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
124 |
+
except RuntimeError:
|
125 |
+
graphs = []
|
126 |
+
|
127 |
+
if hasattr(module, "forward1"):
|
128 |
+
graphs.append(module.forward1.graph)
|
129 |
+
|
130 |
+
for graph in graphs:
|
131 |
+
for node in graph.findAllNodes("aten::to"):
|
132 |
+
inputs = list(node.inputs())
|
133 |
+
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
134 |
+
if inputs[i].node()["value"] == 5:
|
135 |
+
inputs[i].node().copyAttributes(float_node)
|
136 |
+
|
137 |
+
model.apply(patch_float)
|
138 |
+
patch_float(model.encode_image)
|
139 |
+
patch_float(model.encode_text)
|
140 |
+
model.float()
|
141 |
+
|
142 |
+
# ensure image_size attr available at consistent location for both jit and non-jit
|
143 |
+
model.visual.image_size = model.input_resolution.item()
|
144 |
+
return model
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/pretrained.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
import urllib
|
4 |
+
import warnings
|
5 |
+
from typing import Dict, Union
|
6 |
+
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
try:
|
10 |
+
from huggingface_hub import hf_hub_download
|
11 |
+
|
12 |
+
_has_hf_hub = True
|
13 |
+
except ImportError:
|
14 |
+
hf_hub_download = None
|
15 |
+
_has_hf_hub = False
|
16 |
+
|
17 |
+
|
18 |
+
def _pcfg(url="", hf_hub="", filename="", mean=None, std=None):
|
19 |
+
return dict(
|
20 |
+
url=url,
|
21 |
+
hf_hub=hf_hub,
|
22 |
+
mean=mean,
|
23 |
+
std=std,
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
_VITB32 = dict(
|
28 |
+
openai=_pcfg("https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
|
29 |
+
laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
|
30 |
+
laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
|
31 |
+
laion2b_e16=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
|
32 |
+
laion2b_s34b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-B-32-laion2B-s34B-b79K/"),
|
33 |
+
)
|
34 |
+
|
35 |
+
_VITB32_quickgelu = dict(
|
36 |
+
openai=_pcfg("https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
|
37 |
+
laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
|
38 |
+
laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
|
39 |
+
)
|
40 |
+
|
41 |
+
_VITB16 = dict(
|
42 |
+
openai=_pcfg("https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
|
43 |
+
laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
|
44 |
+
laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
|
45 |
+
laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-B-16-laion2B-s34B-b88K/"),
|
46 |
+
)
|
47 |
+
|
48 |
+
_EVAB16 = dict(
|
49 |
+
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt"),
|
50 |
+
eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt"),
|
51 |
+
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt"),
|
52 |
+
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt"),
|
53 |
+
)
|
54 |
+
|
55 |
+
_VITB16_PLUS_240 = dict(
|
56 |
+
laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
|
57 |
+
laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
|
58 |
+
)
|
59 |
+
|
60 |
+
_VITL14 = dict(
|
61 |
+
openai=_pcfg("https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
|
62 |
+
laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
|
63 |
+
laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
|
64 |
+
laion2b_s32b_b82k=_pcfg(hf_hub="laion/CLIP-ViT-L-14-laion2B-s32B-b82K/", mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
65 |
+
)
|
66 |
+
|
67 |
+
_EVAL14 = dict(
|
68 |
+
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_L_psz14.pt"),
|
69 |
+
eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_L_psz14.pt"),
|
70 |
+
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt"),
|
71 |
+
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt"),
|
72 |
+
)
|
73 |
+
|
74 |
+
_VITL14_336 = dict(
|
75 |
+
openai=_pcfg("https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
|
76 |
+
)
|
77 |
+
|
78 |
+
_EVAL14_336 = dict(
|
79 |
+
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt"),
|
80 |
+
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt"),
|
81 |
+
eva_clip_224to336=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt"),
|
82 |
+
eva02_clip_224to336=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt"),
|
83 |
+
)
|
84 |
+
|
85 |
+
_VITH14 = dict(
|
86 |
+
laion2b_s32b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-H-14-laion2B-s32B-b79K/"),
|
87 |
+
)
|
88 |
+
|
89 |
+
_VITg14 = dict(
|
90 |
+
laion2b_s12b_b42k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s12B-b42K/"),
|
91 |
+
laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s34B-b88K/"),
|
92 |
+
)
|
93 |
+
|
94 |
+
_EVAg14 = dict(
|
95 |
+
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/"),
|
96 |
+
eva01=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_g_psz14.pt"),
|
97 |
+
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt"),
|
98 |
+
eva01_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt"),
|
99 |
+
)
|
100 |
+
|
101 |
+
_EVAg14_PLUS = dict(
|
102 |
+
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/"),
|
103 |
+
eva01=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_g_psz14.pt"),
|
104 |
+
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt"),
|
105 |
+
eva01_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt"),
|
106 |
+
)
|
107 |
+
|
108 |
+
_VITbigG14 = dict(
|
109 |
+
laion2b_s39b_b160k=_pcfg(hf_hub="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/"),
|
110 |
+
)
|
111 |
+
|
112 |
+
_EVAbigE14 = dict(
|
113 |
+
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
|
114 |
+
eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
|
115 |
+
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt"),
|
116 |
+
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt"),
|
117 |
+
)
|
118 |
+
|
119 |
+
_EVAbigE14_PLUS = dict(
|
120 |
+
eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
|
121 |
+
eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"),
|
122 |
+
eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt"),
|
123 |
+
eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt"),
|
124 |
+
)
|
125 |
+
|
126 |
+
_EVA_8B = dict(
|
127 |
+
eva=_pcfg(hf_hub="BAAI/EVA-CLIP-8B/EVA_8B_psz14.bin"),
|
128 |
+
eva_clip=_pcfg(hf_hub="BAAI/EVA-CLIP-8B/EVA_CLIP_8B_psz14_s9B.pt"),
|
129 |
+
)
|
130 |
+
|
131 |
+
_EVA_8B_PLUS = dict(
|
132 |
+
eva_clip=_pcfg(hf_hub="BAAI/EVA-CLIP-8B-448/EVA_CLIP_8B_psz14_plus_s0.6B.pt"),
|
133 |
+
)
|
134 |
+
|
135 |
+
|
136 |
+
_PRETRAINED = {
|
137 |
+
# "ViT-B-32": _VITB32,
|
138 |
+
"OpenaiCLIP-B-32": _VITB32,
|
139 |
+
"OpenCLIP-B-32": _VITB32,
|
140 |
+
# "ViT-B-32-quickgelu": _VITB32_quickgelu,
|
141 |
+
"OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu,
|
142 |
+
"OpenCLIP-B-32-quickgelu": _VITB32_quickgelu,
|
143 |
+
# "ViT-B-16": _VITB16,
|
144 |
+
"OpenaiCLIP-B-16": _VITB16,
|
145 |
+
"OpenCLIP-B-16": _VITB16,
|
146 |
+
"EVA02-B-16": _EVAB16,
|
147 |
+
"EVA02-CLIP-B-16": _EVAB16,
|
148 |
+
# "ViT-B-16-plus-240": _VITB16_PLUS_240,
|
149 |
+
"OpenCLIP-B-16-plus-240": _VITB16_PLUS_240,
|
150 |
+
# "ViT-L-14": _VITL14,
|
151 |
+
"OpenaiCLIP-L-14": _VITL14,
|
152 |
+
"OpenCLIP-L-14": _VITL14,
|
153 |
+
"EVA02-L-14": _EVAL14,
|
154 |
+
"EVA02-CLIP-L-14": _EVAL14,
|
155 |
+
# "ViT-L-14-336": _VITL14_336,
|
156 |
+
"OpenaiCLIP-L-14-336": _VITL14_336,
|
157 |
+
"EVA02-CLIP-L-14-336": _EVAL14_336,
|
158 |
+
# "ViT-H-14": _VITH14,
|
159 |
+
# "ViT-g-14": _VITg14,
|
160 |
+
"OpenCLIP-H-14": _VITH14,
|
161 |
+
"OpenCLIP-g-14": _VITg14,
|
162 |
+
"EVA01-CLIP-g-14": _EVAg14,
|
163 |
+
"EVA01-CLIP-g-14-plus": _EVAg14_PLUS,
|
164 |
+
# "ViT-bigG-14": _VITbigG14,
|
165 |
+
"OpenCLIP-bigG-14": _VITbigG14,
|
166 |
+
"EVA02-CLIP-bigE-14": _EVAbigE14,
|
167 |
+
"EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS,
|
168 |
+
"EVA-CLIP-8B": _EVA_8B,
|
169 |
+
"EVA-CLIP-8B-448": _EVA_8B_PLUS,
|
170 |
+
"EVA-CLIP-8B-plus": _EVA_8B_PLUS,
|
171 |
+
}
|
172 |
+
|
173 |
+
|
174 |
+
def _clean_tag(tag: str):
|
175 |
+
# normalize pretrained tags
|
176 |
+
return tag.lower().replace("-", "_")
|
177 |
+
|
178 |
+
|
179 |
+
def list_pretrained(as_str: bool = False):
|
180 |
+
"""returns list of pretrained models
|
181 |
+
Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
|
182 |
+
"""
|
183 |
+
return [":".join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
|
184 |
+
|
185 |
+
|
186 |
+
def list_pretrained_models_by_tag(tag: str):
|
187 |
+
"""return all models having the specified pretrain tag"""
|
188 |
+
models = []
|
189 |
+
tag = _clean_tag(tag)
|
190 |
+
for k in _PRETRAINED.keys():
|
191 |
+
if tag in _PRETRAINED[k]:
|
192 |
+
models.append(k)
|
193 |
+
return models
|
194 |
+
|
195 |
+
|
196 |
+
def list_pretrained_tags_by_model(model: str):
|
197 |
+
"""return all pretrain tags for the specified model architecture"""
|
198 |
+
tags = []
|
199 |
+
if model in _PRETRAINED:
|
200 |
+
tags.extend(_PRETRAINED[model].keys())
|
201 |
+
return tags
|
202 |
+
|
203 |
+
|
204 |
+
def is_pretrained_cfg(model: str, tag: str):
|
205 |
+
if model not in _PRETRAINED:
|
206 |
+
return False
|
207 |
+
return _clean_tag(tag) in _PRETRAINED[model]
|
208 |
+
|
209 |
+
|
210 |
+
def get_pretrained_cfg(model: str, tag: str):
|
211 |
+
if model not in _PRETRAINED:
|
212 |
+
return {}
|
213 |
+
model_pretrained = _PRETRAINED[model]
|
214 |
+
return model_pretrained.get(_clean_tag(tag), {})
|
215 |
+
|
216 |
+
|
217 |
+
def get_pretrained_url(model: str, tag: str):
|
218 |
+
cfg = get_pretrained_cfg(model, _clean_tag(tag))
|
219 |
+
return cfg.get("url", "")
|
220 |
+
|
221 |
+
|
222 |
+
def download_pretrained_from_url(
|
223 |
+
url: str,
|
224 |
+
cache_dir: Union[str, None] = None,
|
225 |
+
):
|
226 |
+
if not cache_dir:
|
227 |
+
cache_dir = os.path.expanduser("~/.cache/clip")
|
228 |
+
os.makedirs(cache_dir, exist_ok=True)
|
229 |
+
filename = os.path.basename(url)
|
230 |
+
|
231 |
+
if "openaipublic" in url:
|
232 |
+
expected_sha256 = url.split("/")[-2]
|
233 |
+
elif "mlfoundations" in url:
|
234 |
+
expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
|
235 |
+
else:
|
236 |
+
expected_sha256 = ""
|
237 |
+
|
238 |
+
download_target = os.path.join(cache_dir, filename)
|
239 |
+
|
240 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
241 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
242 |
+
|
243 |
+
if os.path.isfile(download_target):
|
244 |
+
if expected_sha256:
|
245 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
|
246 |
+
return download_target
|
247 |
+
else:
|
248 |
+
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
249 |
+
else:
|
250 |
+
return download_target
|
251 |
+
|
252 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
253 |
+
with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit="iB", unit_scale=True) as loop:
|
254 |
+
while True:
|
255 |
+
buffer = source.read(8192)
|
256 |
+
if not buffer:
|
257 |
+
break
|
258 |
+
|
259 |
+
output.write(buffer)
|
260 |
+
loop.update(len(buffer))
|
261 |
+
|
262 |
+
if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
|
263 |
+
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
|
264 |
+
|
265 |
+
return download_target
|
266 |
+
|
267 |
+
|
268 |
+
def has_hf_hub(necessary=False):
|
269 |
+
if not _has_hf_hub and necessary:
|
270 |
+
# if no HF Hub module installed, and it is necessary to continue, raise error
|
271 |
+
raise RuntimeError("Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.")
|
272 |
+
return _has_hf_hub
|
273 |
+
|
274 |
+
|
275 |
+
def download_pretrained_from_hf(
|
276 |
+
model_id: str,
|
277 |
+
filename: str = "open_clip_pytorch_model.bin",
|
278 |
+
revision=None,
|
279 |
+
cache_dir: Union[str, None] = None,
|
280 |
+
):
|
281 |
+
has_hf_hub(True)
|
282 |
+
cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
|
283 |
+
return cached_file
|
284 |
+
|
285 |
+
|
286 |
+
def download_pretrained(
|
287 |
+
cfg: Dict,
|
288 |
+
force_hf_hub: bool = False,
|
289 |
+
cache_dir: Union[str, None] = None,
|
290 |
+
):
|
291 |
+
target = ""
|
292 |
+
if not cfg:
|
293 |
+
return target
|
294 |
+
|
295 |
+
download_url = cfg.get("url", "")
|
296 |
+
download_hf_hub = cfg.get("hf_hub", "")
|
297 |
+
if download_hf_hub and force_hf_hub:
|
298 |
+
# use HF hub even if url exists
|
299 |
+
download_url = ""
|
300 |
+
|
301 |
+
if download_url:
|
302 |
+
target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
|
303 |
+
elif download_hf_hub:
|
304 |
+
has_hf_hub(True)
|
305 |
+
# we assume the hf_hub entries in pretrained config combine model_id + filename in
|
306 |
+
# 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
|
307 |
+
# use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
|
308 |
+
model_id, filename = os.path.split(download_hf_hub)
|
309 |
+
if filename:
|
310 |
+
target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
|
311 |
+
else:
|
312 |
+
target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
|
313 |
+
|
314 |
+
return target
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import pi
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from einops import rearrange, repeat
|
5 |
+
import logging
|
6 |
+
|
7 |
+
|
8 |
+
def broadcat(tensors, dim=-1):
|
9 |
+
num_tensors = len(tensors)
|
10 |
+
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
11 |
+
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
12 |
+
shape_len = list(shape_lens)[0]
|
13 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
14 |
+
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
15 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
16 |
+
assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), "invalid dimensions for broadcastable concatentation"
|
17 |
+
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
18 |
+
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
19 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
20 |
+
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
21 |
+
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
22 |
+
return torch.cat(tensors, dim=dim)
|
23 |
+
|
24 |
+
|
25 |
+
def rotate_half(x):
|
26 |
+
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
27 |
+
x1, x2 = x.unbind(dim=-1)
|
28 |
+
x = torch.stack((-x2, x1), dim=-1)
|
29 |
+
return rearrange(x, "... d r -> ... (d r)")
|
30 |
+
|
31 |
+
|
32 |
+
class VisionRotaryEmbedding(nn.Module):
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
dim,
|
36 |
+
pt_seq_len,
|
37 |
+
ft_seq_len=None,
|
38 |
+
custom_freqs=None,
|
39 |
+
freqs_for="lang",
|
40 |
+
theta=10000,
|
41 |
+
max_freq=10,
|
42 |
+
num_freqs=1,
|
43 |
+
):
|
44 |
+
super().__init__()
|
45 |
+
if custom_freqs:
|
46 |
+
freqs = custom_freqs
|
47 |
+
elif freqs_for == "lang":
|
48 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
49 |
+
elif freqs_for == "pixel":
|
50 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
51 |
+
elif freqs_for == "constant":
|
52 |
+
freqs = torch.ones(num_freqs).float()
|
53 |
+
else:
|
54 |
+
raise ValueError(f"unknown modality {freqs_for}")
|
55 |
+
|
56 |
+
if ft_seq_len is None:
|
57 |
+
ft_seq_len = pt_seq_len
|
58 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
59 |
+
|
60 |
+
freqs_h = torch.einsum("..., f -> ... f", t, freqs)
|
61 |
+
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
|
62 |
+
|
63 |
+
freqs_w = torch.einsum("..., f -> ... f", t, freqs)
|
64 |
+
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
|
65 |
+
|
66 |
+
freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
|
67 |
+
|
68 |
+
self.register_buffer("freqs_cos", freqs.cos())
|
69 |
+
self.register_buffer("freqs_sin", freqs.sin())
|
70 |
+
|
71 |
+
logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")
|
72 |
+
|
73 |
+
def forward(self, t, start_index=0):
|
74 |
+
rot_dim = self.freqs_cos.shape[-1]
|
75 |
+
end_index = start_index + rot_dim
|
76 |
+
assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
|
77 |
+
t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
|
78 |
+
t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
|
79 |
+
|
80 |
+
return torch.cat((t_left, t, t_right), dim=-1)
|
81 |
+
|
82 |
+
|
83 |
+
class VisionRotaryEmbeddingFast(nn.Module):
|
84 |
+
def __init__(self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, patch_dropout=0.0):
|
85 |
+
super().__init__()
|
86 |
+
if custom_freqs:
|
87 |
+
freqs = custom_freqs
|
88 |
+
elif freqs_for == "lang":
|
89 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
90 |
+
elif freqs_for == "pixel":
|
91 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
92 |
+
elif freqs_for == "constant":
|
93 |
+
freqs = torch.ones(num_freqs).float()
|
94 |
+
else:
|
95 |
+
raise ValueError(f"unknown modality {freqs_for}")
|
96 |
+
|
97 |
+
if ft_seq_len is None:
|
98 |
+
ft_seq_len = pt_seq_len
|
99 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
100 |
+
|
101 |
+
freqs = torch.einsum("..., f -> ... f", t, freqs)
|
102 |
+
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
|
103 |
+
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
|
104 |
+
|
105 |
+
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
|
106 |
+
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
|
107 |
+
|
108 |
+
self.patch_dropout = patch_dropout
|
109 |
+
|
110 |
+
self.register_buffer("freqs_cos", freqs_cos)
|
111 |
+
self.register_buffer("freqs_sin", freqs_sin)
|
112 |
+
|
113 |
+
logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")
|
114 |
+
|
115 |
+
def forward(self, t, patch_indices_keep=None):
|
116 |
+
if patch_indices_keep is not None:
|
117 |
+
batch = t.size()[0]
|
118 |
+
batch_indices = torch.arange(batch)
|
119 |
+
batch_indices = batch_indices[..., None]
|
120 |
+
|
121 |
+
freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
|
122 |
+
freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
|
123 |
+
|
124 |
+
freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
|
125 |
+
freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j")
|
126 |
+
freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
|
127 |
+
freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j")
|
128 |
+
|
129 |
+
return t * freqs_cos + rotate_half(t) * freqs_sin
|
130 |
+
|
131 |
+
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" timm model adapter
|
2 |
+
|
3 |
+
Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import logging
|
7 |
+
from collections import OrderedDict
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
try:
|
13 |
+
import timm
|
14 |
+
from timm.models.layers import Mlp, to_2tuple
|
15 |
+
|
16 |
+
try:
|
17 |
+
# old timm imports < 0.8.1
|
18 |
+
from timm.models.layers.attention_pool2d import RotAttentionPool2d
|
19 |
+
from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
|
20 |
+
except ImportError:
|
21 |
+
# new timm imports >= 0.8.1
|
22 |
+
from timm.layers import RotAttentionPool2d
|
23 |
+
from timm.layers import AttentionPool2d as AbsAttentionPool2d
|
24 |
+
except ImportError:
|
25 |
+
timm = None
|
26 |
+
|
27 |
+
from .utils import freeze_batch_norm_2d
|
28 |
+
|
29 |
+
|
30 |
+
class TimmModel(nn.Module):
|
31 |
+
"""timm model adapter
|
32 |
+
# FIXME this adapter is a work in progress, may change in ways that break weight compat
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, model_name, embed_dim, image_size=224, pool="avg", proj="linear", proj_bias=False, drop=0.0, pretrained=False):
|
36 |
+
super().__init__()
|
37 |
+
if timm is None:
|
38 |
+
raise RuntimeError("Please `pip install timm` to use timm models.")
|
39 |
+
|
40 |
+
self.image_size = to_2tuple(image_size)
|
41 |
+
self.trunk = timm.create_model(model_name, pretrained=pretrained)
|
42 |
+
feat_size = self.trunk.default_cfg.get("pool_size", None)
|
43 |
+
feature_ndim = 1 if not feat_size else 2
|
44 |
+
if pool in ("abs_attn", "rot_attn"):
|
45 |
+
assert feature_ndim == 2
|
46 |
+
# if attn pooling used, remove both classifier and default pool
|
47 |
+
self.trunk.reset_classifier(0, global_pool="")
|
48 |
+
else:
|
49 |
+
# reset global pool if pool config set, otherwise leave as network default
|
50 |
+
reset_kwargs = dict(global_pool=pool) if pool else {}
|
51 |
+
self.trunk.reset_classifier(0, **reset_kwargs)
|
52 |
+
prev_chs = self.trunk.num_features
|
53 |
+
|
54 |
+
head_layers = OrderedDict()
|
55 |
+
if pool == "abs_attn":
|
56 |
+
head_layers["pool"] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
|
57 |
+
prev_chs = embed_dim
|
58 |
+
elif pool == "rot_attn":
|
59 |
+
head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
|
60 |
+
prev_chs = embed_dim
|
61 |
+
else:
|
62 |
+
assert proj, "projection layer needed if non-attention pooling is used."
|
63 |
+
|
64 |
+
# NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
|
65 |
+
if proj == "linear":
|
66 |
+
head_layers["drop"] = nn.Dropout(drop)
|
67 |
+
head_layers["proj"] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
|
68 |
+
elif proj == "mlp":
|
69 |
+
head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))
|
70 |
+
|
71 |
+
self.head = nn.Sequential(head_layers)
|
72 |
+
|
73 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
74 |
+
"""lock modules
|
75 |
+
Args:
|
76 |
+
unlocked_groups (int): leave last n layer groups unlocked (default: 0)
|
77 |
+
"""
|
78 |
+
if not unlocked_groups:
|
79 |
+
# lock full model
|
80 |
+
for param in self.trunk.parameters():
|
81 |
+
param.requires_grad = False
|
82 |
+
if freeze_bn_stats:
|
83 |
+
freeze_batch_norm_2d(self.trunk)
|
84 |
+
else:
|
85 |
+
# NOTE: partial freeze requires latest timm (master) branch and is subject to change
|
86 |
+
try:
|
87 |
+
# FIXME import here until API stable and in an official release
|
88 |
+
from timm.models.helpers import group_parameters, group_modules
|
89 |
+
except ImportError:
|
90 |
+
raise RuntimeError("Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`")
|
91 |
+
matcher = self.trunk.group_matcher()
|
92 |
+
gparams = group_parameters(self.trunk, matcher)
|
93 |
+
max_layer_id = max(gparams.keys())
|
94 |
+
max_layer_id = max_layer_id - unlocked_groups
|
95 |
+
for group_idx in range(max_layer_id + 1):
|
96 |
+
group = gparams[group_idx]
|
97 |
+
for param in group:
|
98 |
+
self.trunk.get_parameter(param).requires_grad = False
|
99 |
+
if freeze_bn_stats:
|
100 |
+
gmodules = group_modules(self.trunk, matcher, reverse=True)
|
101 |
+
gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
|
102 |
+
freeze_batch_norm_2d(self.trunk, gmodules)
|
103 |
+
|
104 |
+
@torch.jit.ignore
|
105 |
+
def set_grad_checkpointing(self, enable=True):
|
106 |
+
try:
|
107 |
+
self.trunk.set_grad_checkpointing(enable)
|
108 |
+
except Exception as e:
|
109 |
+
logging.warning("grad checkpointing not supported for this timm image tower, continuing without...")
|
110 |
+
|
111 |
+
def forward(self, x):
|
112 |
+
x = self.trunk(x)
|
113 |
+
x = self.head(x)
|
114 |
+
return x
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/tokenizer.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" CLIP tokenizer
|
2 |
+
|
3 |
+
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import gzip
|
7 |
+
import html
|
8 |
+
import os
|
9 |
+
from functools import lru_cache
|
10 |
+
from typing import Union, List
|
11 |
+
|
12 |
+
import ftfy
|
13 |
+
import regex as re
|
14 |
+
import torch
|
15 |
+
|
16 |
+
# https://stackoverflow.com/q/62691279
|
17 |
+
import os
|
18 |
+
|
19 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
20 |
+
|
21 |
+
|
22 |
+
@lru_cache()
|
23 |
+
def default_bpe():
|
24 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
25 |
+
|
26 |
+
|
27 |
+
@lru_cache()
|
28 |
+
def bytes_to_unicode():
|
29 |
+
"""
|
30 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
31 |
+
The reversible bpe codes work on unicode strings.
|
32 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
33 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
34 |
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
35 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
36 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
37 |
+
"""
|
38 |
+
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
39 |
+
cs = bs[:]
|
40 |
+
n = 0
|
41 |
+
for b in range(2**8):
|
42 |
+
if b not in bs:
|
43 |
+
bs.append(b)
|
44 |
+
cs.append(2**8 + n)
|
45 |
+
n += 1
|
46 |
+
cs = [chr(n) for n in cs]
|
47 |
+
return dict(zip(bs, cs))
|
48 |
+
|
49 |
+
|
50 |
+
def get_pairs(word):
|
51 |
+
"""Return set of symbol pairs in a word.
|
52 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
53 |
+
"""
|
54 |
+
pairs = set()
|
55 |
+
prev_char = word[0]
|
56 |
+
for char in word[1:]:
|
57 |
+
pairs.add((prev_char, char))
|
58 |
+
prev_char = char
|
59 |
+
return pairs
|
60 |
+
|
61 |
+
|
62 |
+
def basic_clean(text):
|
63 |
+
text = ftfy.fix_text(text)
|
64 |
+
text = html.unescape(html.unescape(text))
|
65 |
+
return text.strip()
|
66 |
+
|
67 |
+
|
68 |
+
def whitespace_clean(text):
|
69 |
+
text = re.sub(r"\s+", " ", text)
|
70 |
+
text = text.strip()
|
71 |
+
return text
|
72 |
+
|
73 |
+
|
74 |
+
class SimpleTokenizer(object):
|
75 |
+
def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
|
76 |
+
self.byte_encoder = bytes_to_unicode()
|
77 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
78 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
|
79 |
+
merges = merges[1 : 49152 - 256 - 2 + 1]
|
80 |
+
merges = [tuple(merge.split()) for merge in merges]
|
81 |
+
vocab = list(bytes_to_unicode().values())
|
82 |
+
vocab = vocab + [v + "</w>" for v in vocab]
|
83 |
+
for merge in merges:
|
84 |
+
vocab.append("".join(merge))
|
85 |
+
if not special_tokens:
|
86 |
+
special_tokens = ["<start_of_text>", "<end_of_text>"]
|
87 |
+
else:
|
88 |
+
special_tokens = ["<start_of_text>", "<end_of_text>"] + special_tokens
|
89 |
+
vocab.extend(special_tokens)
|
90 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
91 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
92 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
93 |
+
self.cache = {t: t for t in special_tokens}
|
94 |
+
special = "|".join(special_tokens)
|
95 |
+
self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
96 |
+
|
97 |
+
self.vocab_size = len(self.encoder)
|
98 |
+
self.all_special_ids = [self.encoder[t] for t in special_tokens]
|
99 |
+
|
100 |
+
def bpe(self, token):
|
101 |
+
if token in self.cache:
|
102 |
+
return self.cache[token]
|
103 |
+
word = tuple(token[:-1]) + (token[-1] + "</w>",)
|
104 |
+
pairs = get_pairs(word)
|
105 |
+
|
106 |
+
if not pairs:
|
107 |
+
return token + "</w>"
|
108 |
+
|
109 |
+
while True:
|
110 |
+
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
111 |
+
if bigram not in self.bpe_ranks:
|
112 |
+
break
|
113 |
+
first, second = bigram
|
114 |
+
new_word = []
|
115 |
+
i = 0
|
116 |
+
while i < len(word):
|
117 |
+
try:
|
118 |
+
j = word.index(first, i)
|
119 |
+
new_word.extend(word[i:j])
|
120 |
+
i = j
|
121 |
+
except:
|
122 |
+
new_word.extend(word[i:])
|
123 |
+
break
|
124 |
+
|
125 |
+
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
126 |
+
new_word.append(first + second)
|
127 |
+
i += 2
|
128 |
+
else:
|
129 |
+
new_word.append(word[i])
|
130 |
+
i += 1
|
131 |
+
new_word = tuple(new_word)
|
132 |
+
word = new_word
|
133 |
+
if len(word) == 1:
|
134 |
+
break
|
135 |
+
else:
|
136 |
+
pairs = get_pairs(word)
|
137 |
+
word = " ".join(word)
|
138 |
+
self.cache[token] = word
|
139 |
+
return word
|
140 |
+
|
141 |
+
def encode(self, text):
|
142 |
+
bpe_tokens = []
|
143 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
144 |
+
for token in re.findall(self.pat, text):
|
145 |
+
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
|
146 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
|
147 |
+
return bpe_tokens
|
148 |
+
|
149 |
+
def decode(self, tokens):
|
150 |
+
text = "".join([self.decoder[token] for token in tokens])
|
151 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace").replace("</w>", " ")
|
152 |
+
return text
|
153 |
+
|
154 |
+
|
155 |
+
_tokenizer = SimpleTokenizer()
|
156 |
+
|
157 |
+
|
158 |
+
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
|
159 |
+
"""
|
160 |
+
Returns the tokenized representation of given input string(s)
|
161 |
+
|
162 |
+
Parameters
|
163 |
+
----------
|
164 |
+
texts : Union[str, List[str]]
|
165 |
+
An input string or a list of input strings to tokenize
|
166 |
+
context_length : int
|
167 |
+
The context length to use; all CLIP models use 77 as the context length
|
168 |
+
|
169 |
+
Returns
|
170 |
+
-------
|
171 |
+
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
172 |
+
"""
|
173 |
+
if isinstance(texts, str):
|
174 |
+
texts = [texts]
|
175 |
+
|
176 |
+
sot_token = _tokenizer.encoder["<start_of_text>"]
|
177 |
+
eot_token = _tokenizer.encoder["<end_of_text>"]
|
178 |
+
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
179 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
180 |
+
|
181 |
+
for i, tokens in enumerate(all_tokens):
|
182 |
+
if len(tokens) > context_length:
|
183 |
+
tokens = tokens[:context_length] # Truncate
|
184 |
+
tokens[-1] = eot_token
|
185 |
+
result[i, : len(tokens)] = torch.tensor(tokens)
|
186 |
+
|
187 |
+
return result
|
188 |
+
|
189 |
+
|
190 |
+
class HFTokenizer:
|
191 |
+
"HuggingFace tokenizer wrapper"
|
192 |
+
|
193 |
+
def __init__(self, tokenizer_name: str):
|
194 |
+
from transformers import AutoTokenizer
|
195 |
+
|
196 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
197 |
+
|
198 |
+
def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
|
199 |
+
# same cleaning as for default tokenizer, except lowercasing
|
200 |
+
# adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
|
201 |
+
if isinstance(texts, str):
|
202 |
+
texts = [texts]
|
203 |
+
texts = [whitespace_clean(basic_clean(text)) for text in texts]
|
204 |
+
input_ids = self.tokenizer(texts, return_tensors="pt", max_length=context_length, padding="max_length", truncation=True).input_ids
|
205 |
+
return input_ids
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Sequence, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torchvision.transforms.functional as F
|
6 |
+
|
7 |
+
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, CenterCrop
|
8 |
+
|
9 |
+
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
10 |
+
|
11 |
+
|
12 |
+
class ResizeMaxSize(nn.Module):
|
13 |
+
|
14 |
+
def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0):
|
15 |
+
super().__init__()
|
16 |
+
if not isinstance(max_size, int):
|
17 |
+
raise TypeError(f"Size should be int. Got {type(max_size)}")
|
18 |
+
self.max_size = max_size
|
19 |
+
self.interpolation = interpolation
|
20 |
+
self.fn = min if fn == "min" else min
|
21 |
+
self.fill = fill
|
22 |
+
|
23 |
+
def forward(self, img):
|
24 |
+
if isinstance(img, torch.Tensor):
|
25 |
+
height, width = img.shape[:2]
|
26 |
+
else:
|
27 |
+
width, height = img.size
|
28 |
+
scale = self.max_size / float(max(height, width))
|
29 |
+
if scale != 1.0:
|
30 |
+
new_size = tuple(round(dim * scale) for dim in (height, width))
|
31 |
+
img = F.resize(img, new_size, self.interpolation)
|
32 |
+
pad_h = self.max_size - new_size[0]
|
33 |
+
pad_w = self.max_size - new_size[1]
|
34 |
+
img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill)
|
35 |
+
return img
|
36 |
+
|
37 |
+
|
38 |
+
def _convert_to_rgb(image):
|
39 |
+
return image.convert("RGB")
|
40 |
+
|
41 |
+
|
42 |
+
# class CatGen(nn.Module):
|
43 |
+
# def __init__(self, num=4):
|
44 |
+
# self.num = num
|
45 |
+
# def mixgen_batch(image, text):
|
46 |
+
# batch_size = image.shape[0]
|
47 |
+
# index = np.random.permutation(batch_size)
|
48 |
+
|
49 |
+
# cat_images = []
|
50 |
+
# for i in range(batch_size):
|
51 |
+
# # image mixup
|
52 |
+
# image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
|
53 |
+
# # text concat
|
54 |
+
# text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
|
55 |
+
# text = torch.stack(text)
|
56 |
+
# return image, text
|
57 |
+
|
58 |
+
|
59 |
+
def image_transform(
|
60 |
+
image_size: int,
|
61 |
+
is_train: bool,
|
62 |
+
mean: Optional[Tuple[float, ...]] = None,
|
63 |
+
std: Optional[Tuple[float, ...]] = None,
|
64 |
+
resize_longest_max: bool = False,
|
65 |
+
fill_color: int = 0,
|
66 |
+
):
|
67 |
+
mean = mean or OPENAI_DATASET_MEAN
|
68 |
+
if not isinstance(mean, (list, tuple)):
|
69 |
+
mean = (mean,) * 3
|
70 |
+
|
71 |
+
std = std or OPENAI_DATASET_STD
|
72 |
+
if not isinstance(std, (list, tuple)):
|
73 |
+
std = (std,) * 3
|
74 |
+
|
75 |
+
if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
|
76 |
+
# for square size, pass size as int so that Resize() uses aspect preserving shortest edge
|
77 |
+
image_size = image_size[0]
|
78 |
+
|
79 |
+
normalize = Normalize(mean=mean, std=std)
|
80 |
+
if is_train:
|
81 |
+
return Compose(
|
82 |
+
[
|
83 |
+
RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
|
84 |
+
_convert_to_rgb,
|
85 |
+
ToTensor(),
|
86 |
+
normalize,
|
87 |
+
]
|
88 |
+
)
|
89 |
+
else:
|
90 |
+
if resize_longest_max:
|
91 |
+
transforms = [ResizeMaxSize(image_size, fill=fill_color)]
|
92 |
+
else:
|
93 |
+
transforms = [
|
94 |
+
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
|
95 |
+
CenterCrop(image_size),
|
96 |
+
]
|
97 |
+
transforms.extend(
|
98 |
+
[
|
99 |
+
_convert_to_rgb,
|
100 |
+
ToTensor(),
|
101 |
+
normalize,
|
102 |
+
]
|
103 |
+
)
|
104 |
+
return Compose(transforms)
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/transformer.py
ADDED
@@ -0,0 +1,683 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
from collections import OrderedDict
|
4 |
+
import math
|
5 |
+
from typing import Callable, Optional, Sequence
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
from torch.nn import functional as F
|
10 |
+
|
11 |
+
try:
|
12 |
+
from timm.models.layers import trunc_normal_
|
13 |
+
except:
|
14 |
+
from timm.layers import trunc_normal_
|
15 |
+
|
16 |
+
from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
|
17 |
+
from .utils import to_2tuple
|
18 |
+
|
19 |
+
if os.getenv("ENV_TYPE") == "deepspeed":
|
20 |
+
try:
|
21 |
+
import deepspeed
|
22 |
+
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
|
23 |
+
except:
|
24 |
+
print("Please 'pip install deepspeed'")
|
25 |
+
deepspeed = None
|
26 |
+
from torch.utils.checkpoint import checkpoint
|
27 |
+
else:
|
28 |
+
from torch.utils.checkpoint import checkpoint
|
29 |
+
|
30 |
+
try:
|
31 |
+
import xformers.ops as xops
|
32 |
+
except ImportError:
|
33 |
+
xops = None
|
34 |
+
# print("Please 'pip install xformers'")
|
35 |
+
|
36 |
+
|
37 |
+
class LayerNormFp32(nn.LayerNorm):
|
38 |
+
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
|
39 |
+
|
40 |
+
def __init__(self, *args, **kwargs):
|
41 |
+
super().__init__(*args, **kwargs)
|
42 |
+
|
43 |
+
def forward(self, x: torch.Tensor):
|
44 |
+
output = F.layer_norm(
|
45 |
+
x.float(),
|
46 |
+
self.normalized_shape,
|
47 |
+
self.weight.float() if self.weight is not None else None,
|
48 |
+
self.bias.float() if self.bias is not None else None,
|
49 |
+
self.eps,
|
50 |
+
)
|
51 |
+
return output.type_as(x)
|
52 |
+
|
53 |
+
|
54 |
+
class LayerNorm(nn.LayerNorm):
|
55 |
+
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
|
56 |
+
|
57 |
+
def forward(self, x: torch.Tensor):
|
58 |
+
orig_type = x.dtype
|
59 |
+
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
60 |
+
return x.to(orig_type)
|
61 |
+
|
62 |
+
|
63 |
+
class QuickGELU(nn.Module):
|
64 |
+
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
|
65 |
+
def forward(self, x: torch.Tensor):
|
66 |
+
return x * torch.sigmoid(1.702 * x)
|
67 |
+
|
68 |
+
|
69 |
+
class LayerScale(nn.Module):
|
70 |
+
def __init__(self, dim, init_values=1e-5, inplace=False):
|
71 |
+
super().__init__()
|
72 |
+
self.inplace = inplace
|
73 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
74 |
+
|
75 |
+
def forward(self, x):
|
76 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
77 |
+
|
78 |
+
|
79 |
+
class PatchDropout(nn.Module):
|
80 |
+
"""
|
81 |
+
https://arxiv.org/abs/2212.00794
|
82 |
+
"""
|
83 |
+
|
84 |
+
def __init__(self, prob, exclude_first_token=True):
|
85 |
+
super().__init__()
|
86 |
+
assert 0 <= prob < 1.0
|
87 |
+
self.prob = prob
|
88 |
+
self.exclude_first_token = exclude_first_token # exclude CLS token
|
89 |
+
logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
|
90 |
+
|
91 |
+
def forward(self, x):
|
92 |
+
if not self.training or self.prob == 0.0:
|
93 |
+
return x
|
94 |
+
|
95 |
+
if self.exclude_first_token:
|
96 |
+
cls_tokens, x = x[:, :1], x[:, 1:]
|
97 |
+
else:
|
98 |
+
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
99 |
+
|
100 |
+
batch = x.size()[0]
|
101 |
+
num_tokens = x.size()[1]
|
102 |
+
|
103 |
+
batch_indices = torch.arange(batch)
|
104 |
+
batch_indices = batch_indices[..., None]
|
105 |
+
|
106 |
+
keep_prob = 1 - self.prob
|
107 |
+
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
108 |
+
|
109 |
+
rand = torch.randn(batch, num_tokens)
|
110 |
+
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
111 |
+
|
112 |
+
x = x[batch_indices, patch_indices_keep]
|
113 |
+
|
114 |
+
if self.exclude_first_token:
|
115 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
116 |
+
|
117 |
+
if self.training and os.getenv("RoPE") == "1":
|
118 |
+
return x, patch_indices_keep
|
119 |
+
|
120 |
+
return x
|
121 |
+
|
122 |
+
|
123 |
+
def _in_projection_packed(
|
124 |
+
q: torch.Tensor,
|
125 |
+
k: torch.Tensor,
|
126 |
+
v: torch.Tensor,
|
127 |
+
w: torch.Tensor,
|
128 |
+
b: Optional[torch.Tensor] = None,
|
129 |
+
):
|
130 |
+
"""
|
131 |
+
https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726
|
132 |
+
"""
|
133 |
+
E = q.size(-1)
|
134 |
+
if k is v:
|
135 |
+
if q is k:
|
136 |
+
# self-attention
|
137 |
+
return F.linear(q, w, b).chunk(3, dim=-1)
|
138 |
+
else:
|
139 |
+
# encoder-decoder attention
|
140 |
+
w_q, w_kv = w.split([E, E * 2])
|
141 |
+
if b is None:
|
142 |
+
b_q = b_kv = None
|
143 |
+
else:
|
144 |
+
b_q, b_kv = b.split([E, E * 2])
|
145 |
+
return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
|
146 |
+
else:
|
147 |
+
w_q, w_k, w_v = w.chunk(3)
|
148 |
+
if b is None:
|
149 |
+
b_q = b_k = b_v = None
|
150 |
+
else:
|
151 |
+
b_q, b_k, b_v = b.chunk(3)
|
152 |
+
return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
|
153 |
+
|
154 |
+
|
155 |
+
class Attention(nn.Module):
|
156 |
+
def __init__(self, dim, num_heads=8, qkv_bias=True, scaled_cosine=False, scale_heads=False, logit_scale_max=math.log(1.0 / 0.01), attn_drop=0.0, proj_drop=0.0, xattn=False, rope=False):
|
157 |
+
super().__init__()
|
158 |
+
self.scaled_cosine = scaled_cosine
|
159 |
+
self.scale_heads = scale_heads
|
160 |
+
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
161 |
+
self.num_heads = num_heads
|
162 |
+
self.head_dim = dim // num_heads
|
163 |
+
self.scale = self.head_dim**-0.5
|
164 |
+
self.logit_scale_max = logit_scale_max
|
165 |
+
|
166 |
+
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
|
167 |
+
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
|
168 |
+
if qkv_bias:
|
169 |
+
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
|
170 |
+
else:
|
171 |
+
self.in_proj_bias = None
|
172 |
+
|
173 |
+
if self.scaled_cosine:
|
174 |
+
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
|
175 |
+
else:
|
176 |
+
self.logit_scale = None
|
177 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
178 |
+
if self.scale_heads:
|
179 |
+
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
|
180 |
+
else:
|
181 |
+
self.head_scale = None
|
182 |
+
self.out_proj = nn.Linear(dim, dim)
|
183 |
+
self.out_drop = nn.Dropout(proj_drop)
|
184 |
+
self.xattn = xattn
|
185 |
+
self.xattn_drop = attn_drop
|
186 |
+
self.rope = rope
|
187 |
+
|
188 |
+
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
|
189 |
+
L, N, C = x.shape
|
190 |
+
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
|
191 |
+
if self.xattn:
|
192 |
+
q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
|
193 |
+
k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
|
194 |
+
v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
|
195 |
+
|
196 |
+
x = xops.memory_efficient_attention(
|
197 |
+
q,
|
198 |
+
k,
|
199 |
+
v,
|
200 |
+
p=self.xattn_drop,
|
201 |
+
scale=self.scale if self.logit_scale is None else None,
|
202 |
+
attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,
|
203 |
+
)
|
204 |
+
else:
|
205 |
+
q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
206 |
+
k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
207 |
+
v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
208 |
+
|
209 |
+
if self.logit_scale is not None:
|
210 |
+
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
|
211 |
+
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
|
212 |
+
attn = attn.view(N, self.num_heads, L, L) * logit_scale
|
213 |
+
attn = attn.view(-1, L, L)
|
214 |
+
else:
|
215 |
+
q = q * self.scale
|
216 |
+
attn = torch.bmm(q, k.transpose(-1, -2))
|
217 |
+
|
218 |
+
if attn_mask is not None:
|
219 |
+
if attn_mask.dtype == torch.bool:
|
220 |
+
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
221 |
+
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
222 |
+
attn_mask = new_attn_mask
|
223 |
+
attn += attn_mask
|
224 |
+
|
225 |
+
attn = attn.softmax(dim=-1)
|
226 |
+
attn = self.attn_drop(attn)
|
227 |
+
|
228 |
+
x = torch.bmm(attn, v)
|
229 |
+
|
230 |
+
if self.head_scale is not None:
|
231 |
+
x = x.view(N, self.num_heads, L, C) * self.head_scale
|
232 |
+
x = x.view(-1, L, C)
|
233 |
+
x = x.transpose(0, 1).reshape(L, N, C)
|
234 |
+
x = self.out_proj(x)
|
235 |
+
x = self.out_drop(x)
|
236 |
+
return x
|
237 |
+
|
238 |
+
|
239 |
+
class CustomAttention(nn.Module):
|
240 |
+
def __init__(self, dim, num_heads=8, qkv_bias=True, scaled_cosine=True, scale_heads=False, logit_scale_max=math.log(1.0 / 0.01), attn_drop=0.0, proj_drop=0.0, xattn=False):
|
241 |
+
super().__init__()
|
242 |
+
self.scaled_cosine = scaled_cosine
|
243 |
+
self.scale_heads = scale_heads
|
244 |
+
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
245 |
+
self.num_heads = num_heads
|
246 |
+
self.head_dim = dim // num_heads
|
247 |
+
self.scale = self.head_dim**-0.5
|
248 |
+
self.logit_scale_max = logit_scale_max
|
249 |
+
|
250 |
+
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
|
251 |
+
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
|
252 |
+
if qkv_bias:
|
253 |
+
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
|
254 |
+
else:
|
255 |
+
self.in_proj_bias = None
|
256 |
+
|
257 |
+
if self.scaled_cosine:
|
258 |
+
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
|
259 |
+
else:
|
260 |
+
self.logit_scale = None
|
261 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
262 |
+
if self.scale_heads:
|
263 |
+
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
|
264 |
+
else:
|
265 |
+
self.head_scale = None
|
266 |
+
self.out_proj = nn.Linear(dim, dim)
|
267 |
+
self.out_drop = nn.Dropout(proj_drop)
|
268 |
+
self.xattn = xattn
|
269 |
+
self.xattn_drop = attn_drop
|
270 |
+
|
271 |
+
def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
272 |
+
q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
|
273 |
+
N_q, B_q, C_q = q.shape
|
274 |
+
N_k, B_k, C_k = k.shape
|
275 |
+
N_v, B_v, C_v = v.shape
|
276 |
+
if self.xattn:
|
277 |
+
# B, N, C -> B, N, num_heads, C
|
278 |
+
q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1)
|
279 |
+
k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1)
|
280 |
+
v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1)
|
281 |
+
|
282 |
+
x = xops.memory_efficient_attention(q, k, v, p=self.xattn_drop, scale=self.scale if self.logit_scale is None else None, attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None)
|
283 |
+
else:
|
284 |
+
# B*H, L, C
|
285 |
+
q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1)
|
286 |
+
k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1)
|
287 |
+
v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1)
|
288 |
+
|
289 |
+
if self.logit_scale is not None:
|
290 |
+
# B*H, N_q, N_k
|
291 |
+
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
|
292 |
+
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
|
293 |
+
attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale
|
294 |
+
attn = attn.view(-1, N_q, N_k)
|
295 |
+
else:
|
296 |
+
q = q * self.scale
|
297 |
+
attn = torch.bmm(q, k.transpose(-1, -2))
|
298 |
+
|
299 |
+
if attn_mask is not None:
|
300 |
+
if attn_mask.dtype == torch.bool:
|
301 |
+
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
302 |
+
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
303 |
+
attn_mask = new_attn_mask
|
304 |
+
attn += attn_mask
|
305 |
+
|
306 |
+
attn = attn.softmax(dim=-1)
|
307 |
+
attn = self.attn_drop(attn)
|
308 |
+
|
309 |
+
x = torch.bmm(attn, v)
|
310 |
+
|
311 |
+
if self.head_scale is not None:
|
312 |
+
x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale
|
313 |
+
x = x.view(-1, N_q, C_q)
|
314 |
+
x = x.transpose(0, 1).reshape(N_q, B_q, C_q)
|
315 |
+
x = self.out_proj(x)
|
316 |
+
x = self.out_drop(x)
|
317 |
+
return x
|
318 |
+
|
319 |
+
|
320 |
+
class CustomResidualAttentionBlock(nn.Module):
|
321 |
+
def __init__(
|
322 |
+
self,
|
323 |
+
d_model: int,
|
324 |
+
n_head: int,
|
325 |
+
mlp_ratio: float = 4.0,
|
326 |
+
ls_init_value: float = None,
|
327 |
+
act_layer: Callable = nn.GELU,
|
328 |
+
norm_layer: Callable = LayerNorm,
|
329 |
+
scale_cosine_attn: bool = False,
|
330 |
+
scale_heads: bool = False,
|
331 |
+
scale_attn: bool = False,
|
332 |
+
scale_fc: bool = False,
|
333 |
+
cross_attn: bool = False,
|
334 |
+
xattn: bool = False,
|
335 |
+
):
|
336 |
+
super().__init__()
|
337 |
+
|
338 |
+
self.ln_1 = norm_layer(d_model)
|
339 |
+
self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1
|
340 |
+
self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1
|
341 |
+
self.attn = CustomAttention(d_model, n_head, qkv_bias=True, attn_drop=0.0, proj_drop=0.0, scaled_cosine=scale_cosine_attn, scale_heads=scale_heads, xattn=xattn)
|
342 |
+
|
343 |
+
self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
|
344 |
+
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
345 |
+
|
346 |
+
self.ln_2 = norm_layer(d_model)
|
347 |
+
mlp_width = int(d_model * mlp_ratio)
|
348 |
+
self.mlp = nn.Sequential(OrderedDict([("c_fc", nn.Linear(d_model, mlp_width)), ("ln", norm_layer(mlp_width) if scale_fc else nn.Identity()), ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, d_model))]))
|
349 |
+
|
350 |
+
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
351 |
+
|
352 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
353 |
+
q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask)))
|
354 |
+
q = q + self.ls_2(self.mlp(self.ln_2(q)))
|
355 |
+
return q
|
356 |
+
|
357 |
+
|
358 |
+
class CustomTransformer(nn.Module):
|
359 |
+
def __init__(
|
360 |
+
self,
|
361 |
+
width: int,
|
362 |
+
layers: int,
|
363 |
+
heads: int,
|
364 |
+
mlp_ratio: float = 4.0,
|
365 |
+
ls_init_value: float = None,
|
366 |
+
act_layer: Callable = nn.GELU,
|
367 |
+
norm_layer: Callable = LayerNorm,
|
368 |
+
scale_cosine_attn: bool = True,
|
369 |
+
scale_heads: bool = False,
|
370 |
+
scale_attn: bool = False,
|
371 |
+
scale_fc: bool = False,
|
372 |
+
cross_attn: bool = False,
|
373 |
+
xattn: bool = False,
|
374 |
+
):
|
375 |
+
super().__init__()
|
376 |
+
self.width = width
|
377 |
+
self.layers = layers
|
378 |
+
self.grad_checkpointing = False
|
379 |
+
self.xattn = xattn
|
380 |
+
|
381 |
+
self.resblocks = nn.ModuleList(
|
382 |
+
[
|
383 |
+
CustomResidualAttentionBlock(
|
384 |
+
width,
|
385 |
+
heads,
|
386 |
+
mlp_ratio,
|
387 |
+
ls_init_value=ls_init_value,
|
388 |
+
act_layer=act_layer,
|
389 |
+
norm_layer=norm_layer,
|
390 |
+
scale_cosine_attn=scale_cosine_attn,
|
391 |
+
scale_heads=scale_heads,
|
392 |
+
scale_attn=scale_attn,
|
393 |
+
scale_fc=scale_fc,
|
394 |
+
cross_attn=cross_attn,
|
395 |
+
xattn=xattn,
|
396 |
+
)
|
397 |
+
for _ in range(layers)
|
398 |
+
]
|
399 |
+
)
|
400 |
+
|
401 |
+
def get_cast_dtype(self) -> torch.dtype:
|
402 |
+
return self.resblocks[0].mlp.c_fc.weight.dtype
|
403 |
+
|
404 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None):
|
405 |
+
if k is None and v is None:
|
406 |
+
k = v = q
|
407 |
+
for r in self.resblocks:
|
408 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
409 |
+
q = checkpoint(r, q, k, v, attn_mask)
|
410 |
+
else:
|
411 |
+
q = r(q, k, v, attn_mask=attn_mask)
|
412 |
+
return q
|
413 |
+
|
414 |
+
|
415 |
+
class ResidualAttentionBlock(nn.Module):
|
416 |
+
def __init__(
|
417 |
+
self,
|
418 |
+
d_model: int,
|
419 |
+
n_head: int,
|
420 |
+
mlp_ratio: float = 4.0,
|
421 |
+
ls_init_value: float = None,
|
422 |
+
act_layer: Callable = nn.GELU,
|
423 |
+
norm_layer: Callable = LayerNorm,
|
424 |
+
xattn: bool = False,
|
425 |
+
):
|
426 |
+
super().__init__()
|
427 |
+
|
428 |
+
self.ln_1 = norm_layer(d_model)
|
429 |
+
if xattn:
|
430 |
+
self.attn = Attention(d_model, n_head, xattn=True)
|
431 |
+
else:
|
432 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
433 |
+
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
434 |
+
|
435 |
+
self.ln_2 = norm_layer(d_model)
|
436 |
+
mlp_width = int(d_model * mlp_ratio)
|
437 |
+
self.mlp = nn.Sequential(OrderedDict([("c_fc", nn.Linear(d_model, mlp_width)), ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, d_model))]))
|
438 |
+
|
439 |
+
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
440 |
+
self.xattn = xattn
|
441 |
+
|
442 |
+
def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
443 |
+
attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
|
444 |
+
if self.xattn:
|
445 |
+
return self.attn(x, attn_mask=attn_mask)
|
446 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
|
447 |
+
|
448 |
+
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
449 |
+
x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))
|
450 |
+
x = x + self.ls_2(self.mlp(self.ln_2(x)))
|
451 |
+
return x
|
452 |
+
|
453 |
+
|
454 |
+
class Transformer(nn.Module):
|
455 |
+
def __init__(
|
456 |
+
self,
|
457 |
+
width: int,
|
458 |
+
layers: int,
|
459 |
+
heads: int,
|
460 |
+
mlp_ratio: float = 4.0,
|
461 |
+
ls_init_value: float = None,
|
462 |
+
act_layer: Callable = nn.GELU,
|
463 |
+
norm_layer: Callable = LayerNorm,
|
464 |
+
xattn: bool = False,
|
465 |
+
):
|
466 |
+
super().__init__()
|
467 |
+
self.width = width
|
468 |
+
self.layers = layers
|
469 |
+
self.grad_checkpointing = False
|
470 |
+
|
471 |
+
self.resblocks = nn.ModuleList([ResidualAttentionBlock(width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn) for _ in range(layers)])
|
472 |
+
|
473 |
+
def get_cast_dtype(self) -> torch.dtype:
|
474 |
+
return self.resblocks[0].mlp.c_fc.weight.dtype
|
475 |
+
|
476 |
+
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
477 |
+
for r in self.resblocks:
|
478 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
479 |
+
x = checkpoint(r, x, attn_mask)
|
480 |
+
else:
|
481 |
+
x = r(x, attn_mask=attn_mask)
|
482 |
+
return x
|
483 |
+
|
484 |
+
|
485 |
+
class VisionTransformer(nn.Module):
|
486 |
+
def __init__(
|
487 |
+
self,
|
488 |
+
image_size: int,
|
489 |
+
patch_size: int,
|
490 |
+
width: int,
|
491 |
+
layers: int,
|
492 |
+
heads: int,
|
493 |
+
mlp_ratio: float,
|
494 |
+
ls_init_value: float = None,
|
495 |
+
patch_dropout: float = 0.0,
|
496 |
+
global_average_pool: bool = False,
|
497 |
+
output_dim: int = 512,
|
498 |
+
act_layer: Callable = nn.GELU,
|
499 |
+
norm_layer: Callable = LayerNorm,
|
500 |
+
xattn: bool = False,
|
501 |
+
):
|
502 |
+
super().__init__()
|
503 |
+
self.image_size = to_2tuple(image_size)
|
504 |
+
self.patch_size = to_2tuple(patch_size)
|
505 |
+
self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
|
506 |
+
self.output_dim = output_dim
|
507 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
508 |
+
|
509 |
+
scale = width**-0.5
|
510 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
511 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
|
512 |
+
|
513 |
+
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
|
514 |
+
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
|
515 |
+
self.ln_pre = norm_layer(width)
|
516 |
+
|
517 |
+
self.transformer = Transformer(width, layers, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn)
|
518 |
+
|
519 |
+
self.global_average_pool = global_average_pool
|
520 |
+
self.ln_post = norm_layer(width)
|
521 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
522 |
+
|
523 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
524 |
+
for param in self.parameters():
|
525 |
+
param.requires_grad = False
|
526 |
+
|
527 |
+
if unlocked_groups != 0:
|
528 |
+
groups = [
|
529 |
+
[
|
530 |
+
self.conv1,
|
531 |
+
self.class_embedding,
|
532 |
+
self.positional_embedding,
|
533 |
+
self.ln_pre,
|
534 |
+
],
|
535 |
+
*self.transformer.resblocks[:-1],
|
536 |
+
[
|
537 |
+
self.transformer.resblocks[-1],
|
538 |
+
self.ln_post,
|
539 |
+
],
|
540 |
+
self.proj,
|
541 |
+
]
|
542 |
+
|
543 |
+
def _unlock(x):
|
544 |
+
if isinstance(x, Sequence):
|
545 |
+
for g in x:
|
546 |
+
_unlock(g)
|
547 |
+
else:
|
548 |
+
if isinstance(x, torch.nn.Parameter):
|
549 |
+
x.requires_grad = True
|
550 |
+
else:
|
551 |
+
for p in x.parameters():
|
552 |
+
p.requires_grad = True
|
553 |
+
|
554 |
+
_unlock(groups[-unlocked_groups:])
|
555 |
+
|
556 |
+
def get_num_layers(self):
|
557 |
+
return self.transformer.layers
|
558 |
+
|
559 |
+
@torch.jit.ignore
|
560 |
+
def set_grad_checkpointing(self, enable=True):
|
561 |
+
self.transformer.grad_checkpointing = enable
|
562 |
+
|
563 |
+
@torch.jit.ignore
|
564 |
+
def no_weight_decay(self):
|
565 |
+
return {"positional_embedding", "class_embedding"}
|
566 |
+
|
567 |
+
def forward(self, x: torch.Tensor, return_all_features: bool = False):
|
568 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
569 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
570 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
571 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
572 |
+
x = x + self.positional_embedding.to(x.dtype)
|
573 |
+
|
574 |
+
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
|
575 |
+
x = self.patch_dropout(x)
|
576 |
+
x = self.ln_pre(x)
|
577 |
+
|
578 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
579 |
+
x = self.transformer(x)
|
580 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
581 |
+
|
582 |
+
if not return_all_features:
|
583 |
+
if self.global_average_pool:
|
584 |
+
x = x.mean(dim=1) # x = x[:,1:,:].mean(dim=1)
|
585 |
+
else:
|
586 |
+
x = x[:, 0]
|
587 |
+
|
588 |
+
x = self.ln_post(x)
|
589 |
+
|
590 |
+
if self.proj is not None:
|
591 |
+
x = x @ self.proj
|
592 |
+
|
593 |
+
return x
|
594 |
+
|
595 |
+
|
596 |
+
class TextTransformer(nn.Module):
|
597 |
+
def __init__(
|
598 |
+
self,
|
599 |
+
context_length: int = 77,
|
600 |
+
vocab_size: int = 49408,
|
601 |
+
width: int = 512,
|
602 |
+
heads: int = 8,
|
603 |
+
layers: int = 12,
|
604 |
+
ls_init_value: float = None,
|
605 |
+
output_dim: int = 512,
|
606 |
+
act_layer: Callable = nn.GELU,
|
607 |
+
norm_layer: Callable = LayerNorm,
|
608 |
+
xattn: bool = False,
|
609 |
+
attn_mask: bool = True,
|
610 |
+
):
|
611 |
+
super().__init__()
|
612 |
+
self.context_length = context_length
|
613 |
+
self.vocab_size = vocab_size
|
614 |
+
self.width = width
|
615 |
+
self.output_dim = output_dim
|
616 |
+
|
617 |
+
self.token_embedding = nn.Embedding(vocab_size, width)
|
618 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
|
619 |
+
self.transformer = Transformer(width=width, layers=layers, heads=heads, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn)
|
620 |
+
|
621 |
+
self.xattn = xattn
|
622 |
+
self.ln_final = norm_layer(width)
|
623 |
+
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
|
624 |
+
|
625 |
+
if attn_mask:
|
626 |
+
self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
|
627 |
+
else:
|
628 |
+
self.attn_mask = None
|
629 |
+
|
630 |
+
self.init_parameters()
|
631 |
+
|
632 |
+
def init_parameters(self):
|
633 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
634 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
635 |
+
|
636 |
+
proj_std = (self.transformer.width**-0.5) * ((2 * self.transformer.layers) ** -0.5)
|
637 |
+
attn_std = self.transformer.width**-0.5
|
638 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
639 |
+
for block in self.transformer.resblocks:
|
640 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
641 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
642 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
643 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
644 |
+
|
645 |
+
if self.text_projection is not None:
|
646 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)
|
647 |
+
|
648 |
+
@torch.jit.ignore
|
649 |
+
def set_grad_checkpointing(self, enable=True):
|
650 |
+
self.transformer.grad_checkpointing = enable
|
651 |
+
|
652 |
+
@torch.jit.ignore
|
653 |
+
def no_weight_decay(self):
|
654 |
+
# return {'positional_embedding', 'token_embedding'}
|
655 |
+
return {"positional_embedding"}
|
656 |
+
|
657 |
+
def get_num_layers(self):
|
658 |
+
return self.transformer.layers
|
659 |
+
|
660 |
+
def build_attention_mask(self):
|
661 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
662 |
+
# pytorch uses additive attention mask; fill with -inf
|
663 |
+
mask = torch.empty(self.context_length, self.context_length)
|
664 |
+
mask.fill_(float("-inf"))
|
665 |
+
mask.triu_(1) # zero out the lower diagonal
|
666 |
+
return mask
|
667 |
+
|
668 |
+
def forward(self, text, return_all_features: bool = False):
|
669 |
+
cast_dtype = self.transformer.get_cast_dtype()
|
670 |
+
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
671 |
+
|
672 |
+
x = x + self.positional_embedding.to(cast_dtype)
|
673 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
674 |
+
x = self.transformer(x, attn_mask=self.attn_mask)
|
675 |
+
# x = self.transformer(x) # no attention mask is applied
|
676 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
677 |
+
x = self.ln_final(x)
|
678 |
+
|
679 |
+
if not return_all_features:
|
680 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
681 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
682 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
683 |
+
return x
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/utils.py
ADDED
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from itertools import repeat
|
2 |
+
import collections.abc
|
3 |
+
import logging
|
4 |
+
import math
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn as nn
|
9 |
+
from torchvision.ops.misc import FrozenBatchNorm2d
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
|
13 |
+
# open CLIP
|
14 |
+
def resize_clip_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
|
15 |
+
# Rescale the grid of position embeddings when loading from state_dict
|
16 |
+
old_pos_embed = state_dict.get("visual.positional_embedding", None)
|
17 |
+
if old_pos_embed is None or not hasattr(model.visual, "grid_size"):
|
18 |
+
return
|
19 |
+
grid_size = to_2tuple(model.visual.grid_size)
|
20 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
21 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
22 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
23 |
+
return
|
24 |
+
|
25 |
+
if extra_tokens:
|
26 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
27 |
+
else:
|
28 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
29 |
+
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
30 |
+
|
31 |
+
logging.info("Resizing position embedding grid-size from %s to %s", old_grid_size, grid_size)
|
32 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
33 |
+
pos_emb_img = F.interpolate(
|
34 |
+
pos_emb_img,
|
35 |
+
size=grid_size,
|
36 |
+
mode=interpolation,
|
37 |
+
align_corners=True,
|
38 |
+
)
|
39 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
40 |
+
if pos_emb_tok is not None:
|
41 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
42 |
+
else:
|
43 |
+
new_pos_embed = pos_emb_img
|
44 |
+
state_dict["visual.positional_embedding"] = new_pos_embed
|
45 |
+
|
46 |
+
|
47 |
+
def resize_visual_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
|
48 |
+
# Rescale the grid of position embeddings when loading from state_dict
|
49 |
+
old_pos_embed = state_dict.get("positional_embedding", None)
|
50 |
+
if old_pos_embed is None or not hasattr(model.visual, "grid_size"):
|
51 |
+
return
|
52 |
+
grid_size = to_2tuple(model.visual.grid_size)
|
53 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
54 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
55 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
56 |
+
return
|
57 |
+
|
58 |
+
if extra_tokens:
|
59 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
60 |
+
else:
|
61 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
62 |
+
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
63 |
+
|
64 |
+
logging.info("Resizing position embedding grid-size from %s to %s", old_grid_size, grid_size)
|
65 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
66 |
+
pos_emb_img = F.interpolate(
|
67 |
+
pos_emb_img,
|
68 |
+
size=grid_size,
|
69 |
+
mode=interpolation,
|
70 |
+
align_corners=True,
|
71 |
+
)
|
72 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
73 |
+
if pos_emb_tok is not None:
|
74 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
75 |
+
else:
|
76 |
+
new_pos_embed = pos_emb_img
|
77 |
+
state_dict["positional_embedding"] = new_pos_embed
|
78 |
+
|
79 |
+
|
80 |
+
def resize_evaclip_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
|
81 |
+
all_keys = list(state_dict.keys())
|
82 |
+
# interpolate position embedding
|
83 |
+
if "visual.pos_embed" in state_dict:
|
84 |
+
pos_embed_checkpoint = state_dict["visual.pos_embed"]
|
85 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
86 |
+
num_patches = model.visual.patch_embed.num_patches
|
87 |
+
# num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
|
88 |
+
num_extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
89 |
+
# height (== width) for the checkpoint position embedding
|
90 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
91 |
+
# height (== width) for the new position embedding
|
92 |
+
new_size = int(num_patches**0.5)
|
93 |
+
# class_token and dist_token are kept unchanged
|
94 |
+
if orig_size != new_size:
|
95 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
96 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
97 |
+
# only the position tokens are interpolated
|
98 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
99 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
100 |
+
pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False)
|
101 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
102 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
103 |
+
state_dict["visual.pos_embed"] = new_pos_embed
|
104 |
+
|
105 |
+
patch_embed_proj = state_dict["visual.patch_embed.proj.weight"]
|
106 |
+
patch_size = model.visual.patch_embed.patch_size
|
107 |
+
state_dict["visual.patch_embed.proj.weight"] = torch.nn.functional.interpolate(patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False)
|
108 |
+
|
109 |
+
|
110 |
+
def resize_eva_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
|
111 |
+
all_keys = list(state_dict.keys())
|
112 |
+
# interpolate position embedding
|
113 |
+
if "pos_embed" in state_dict:
|
114 |
+
pos_embed_checkpoint = state_dict["pos_embed"]
|
115 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
116 |
+
num_patches = model.visual.patch_embed.num_patches
|
117 |
+
# num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
|
118 |
+
num_extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
119 |
+
# height (== width) for the checkpoint position embedding
|
120 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
121 |
+
# height (== width) for the new position embedding
|
122 |
+
new_size = int(num_patches**0.5)
|
123 |
+
# class_token and dist_token are kept unchanged
|
124 |
+
if orig_size != new_size:
|
125 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
126 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
127 |
+
# only the position tokens are interpolated
|
128 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
129 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
130 |
+
pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False)
|
131 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
132 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
133 |
+
state_dict["pos_embed"] = new_pos_embed
|
134 |
+
|
135 |
+
patch_embed_proj = state_dict["patch_embed.proj.weight"]
|
136 |
+
patch_size = model.visual.patch_embed.patch_size
|
137 |
+
state_dict["patch_embed.proj.weight"] = torch.nn.functional.interpolate(patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False)
|
138 |
+
|
139 |
+
|
140 |
+
def resize_rel_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1):
|
141 |
+
all_keys = list(state_dict.keys())
|
142 |
+
for key in all_keys:
|
143 |
+
if "relative_position_index" in key:
|
144 |
+
state_dict.pop(key)
|
145 |
+
|
146 |
+
if "relative_position_bias_table" in key:
|
147 |
+
rel_pos_bias = state_dict[key]
|
148 |
+
src_num_pos, num_attn_heads = rel_pos_bias.size()
|
149 |
+
dst_num_pos, _ = model.visual.state_dict()[key].size()
|
150 |
+
dst_patch_shape = model.visual.patch_embed.patch_shape
|
151 |
+
if dst_patch_shape[0] != dst_patch_shape[1]:
|
152 |
+
raise NotImplementedError()
|
153 |
+
num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
|
154 |
+
src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
|
155 |
+
dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
|
156 |
+
if src_size != dst_size:
|
157 |
+
print("Position interpolate for %s from %dx%d to %dx%d" % (key, src_size, src_size, dst_size, dst_size))
|
158 |
+
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
|
159 |
+
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
|
160 |
+
|
161 |
+
def geometric_progression(a, r, n):
|
162 |
+
return a * (1.0 - r**n) / (1.0 - r)
|
163 |
+
|
164 |
+
left, right = 1.01, 1.5
|
165 |
+
while right - left > 1e-6:
|
166 |
+
q = (left + right) / 2.0
|
167 |
+
gp = geometric_progression(1, q, src_size // 2)
|
168 |
+
if gp > dst_size // 2:
|
169 |
+
right = q
|
170 |
+
else:
|
171 |
+
left = q
|
172 |
+
|
173 |
+
# if q > 1.090307:
|
174 |
+
# q = 1.090307
|
175 |
+
|
176 |
+
dis = []
|
177 |
+
cur = 1
|
178 |
+
for i in range(src_size // 2):
|
179 |
+
dis.append(cur)
|
180 |
+
cur += q ** (i + 1)
|
181 |
+
|
182 |
+
r_ids = [-_ for _ in reversed(dis)]
|
183 |
+
|
184 |
+
x = r_ids + [0] + dis
|
185 |
+
y = r_ids + [0] + dis
|
186 |
+
|
187 |
+
t = dst_size // 2.0
|
188 |
+
dx = np.arange(-t, t + 0.1, 1.0)
|
189 |
+
dy = np.arange(-t, t + 0.1, 1.0)
|
190 |
+
|
191 |
+
print("Original positions = %s" % str(x))
|
192 |
+
print("Target positions = %s" % str(dx))
|
193 |
+
|
194 |
+
all_rel_pos_bias = []
|
195 |
+
|
196 |
+
for i in range(num_attn_heads):
|
197 |
+
z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
|
198 |
+
f = F.interpolate.interp2d(x, y, z, kind="cubic")
|
199 |
+
all_rel_pos_bias.append(torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
|
200 |
+
|
201 |
+
rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
|
202 |
+
|
203 |
+
new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
|
204 |
+
state_dict[key] = new_rel_pos_bias
|
205 |
+
|
206 |
+
# interpolate position embedding
|
207 |
+
if "pos_embed" in state_dict:
|
208 |
+
pos_embed_checkpoint = state_dict["pos_embed"]
|
209 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
210 |
+
num_patches = model.visual.patch_embed.num_patches
|
211 |
+
num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
|
212 |
+
# height (== width) for the checkpoint position embedding
|
213 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
214 |
+
# height (== width) for the new position embedding
|
215 |
+
new_size = int(num_patches**0.5)
|
216 |
+
# class_token and dist_token are kept unchanged
|
217 |
+
if orig_size != new_size:
|
218 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
219 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
220 |
+
# only the position tokens are interpolated
|
221 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
222 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
223 |
+
pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False)
|
224 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
225 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
226 |
+
state_dict["pos_embed"] = new_pos_embed
|
227 |
+
|
228 |
+
patch_embed_proj = state_dict["patch_embed.proj.weight"]
|
229 |
+
patch_size = model.visual.patch_embed.patch_size
|
230 |
+
state_dict["patch_embed.proj.weight"] = torch.nn.functional.interpolate(patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False)
|
231 |
+
|
232 |
+
|
233 |
+
def freeze_batch_norm_2d(module, module_match={}, name=""):
|
234 |
+
"""
|
235 |
+
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
|
236 |
+
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
|
237 |
+
returned. Otherwise, the module is walked recursively and submodules are converted in place.
|
238 |
+
|
239 |
+
Args:
|
240 |
+
module (torch.nn.Module): Any PyTorch module.
|
241 |
+
module_match (dict): Dictionary of full module names to freeze (all if empty)
|
242 |
+
name (str): Full module name (prefix)
|
243 |
+
|
244 |
+
Returns:
|
245 |
+
torch.nn.Module: Resulting module
|
246 |
+
|
247 |
+
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
|
248 |
+
"""
|
249 |
+
res = module
|
250 |
+
is_match = True
|
251 |
+
if module_match:
|
252 |
+
is_match = name in module_match
|
253 |
+
if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
|
254 |
+
res = FrozenBatchNorm2d(module.num_features)
|
255 |
+
res.num_features = module.num_features
|
256 |
+
res.affine = module.affine
|
257 |
+
if module.affine:
|
258 |
+
res.weight.data = module.weight.data.clone().detach()
|
259 |
+
res.bias.data = module.bias.data.clone().detach()
|
260 |
+
res.running_mean.data = module.running_mean.data
|
261 |
+
res.running_var.data = module.running_var.data
|
262 |
+
res.eps = module.eps
|
263 |
+
else:
|
264 |
+
for child_name, child in module.named_children():
|
265 |
+
full_child_name = ".".join([name, child_name]) if name else child_name
|
266 |
+
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
|
267 |
+
if new_child is not child:
|
268 |
+
res.add_module(child_name, new_child)
|
269 |
+
return res
|
270 |
+
|
271 |
+
|
272 |
+
# From PyTorch internals
|
273 |
+
def _ntuple(n):
|
274 |
+
def parse(x):
|
275 |
+
if isinstance(x, collections.abc.Iterable):
|
276 |
+
return x
|
277 |
+
return tuple(repeat(x, n))
|
278 |
+
|
279 |
+
return parse
|
280 |
+
|
281 |
+
|
282 |
+
to_1tuple = _ntuple(1)
|
283 |
+
to_2tuple = _ntuple(2)
|
284 |
+
to_3tuple = _ntuple(3)
|
285 |
+
to_4tuple = _ntuple(4)
|
286 |
+
to_ntuple = lambda n, x: _ntuple(n)(x)
|
287 |
+
|
288 |
+
|
289 |
+
def is_logging(args):
|
290 |
+
def is_global_master(args):
|
291 |
+
return args.rank == 0
|
292 |
+
|
293 |
+
def is_local_master(args):
|
294 |
+
return args.local_rank == 0
|
295 |
+
|
296 |
+
def is_master(args, local=False):
|
297 |
+
return is_local_master(args) if local else is_global_master(args)
|
298 |
+
|
299 |
+
return is_master
|
300 |
+
|
301 |
+
|
302 |
+
class AllGather(torch.autograd.Function):
|
303 |
+
"""An autograd function that performs allgather on a tensor.
|
304 |
+
Performs all_gather operation on the provided tensors.
|
305 |
+
*** Warning ***: torch.distributed.all_gather has no gradient.
|
306 |
+
"""
|
307 |
+
|
308 |
+
@staticmethod
|
309 |
+
def forward(ctx, tensor, rank, world_size):
|
310 |
+
tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)]
|
311 |
+
torch.distributed.all_gather(tensors_gather, tensor)
|
312 |
+
ctx.rank = rank
|
313 |
+
ctx.batch_size = tensor.shape[0]
|
314 |
+
return torch.cat(tensors_gather, 0)
|
315 |
+
|
316 |
+
@staticmethod
|
317 |
+
def backward(ctx, grad_output):
|
318 |
+
return (grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)], None, None)
|
319 |
+
|
320 |
+
|
321 |
+
allgather = AllGather.apply
|
blip3o/model/multimodal_encoder/dev_eva_clip/eva_vit.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Based on EVA, BEIT, timm and DeiT code bases
|
2 |
+
# https://github.com/baaivision/EVA
|
3 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
4 |
+
# https://github.com/microsoft/unilm/tree/master/beit
|
5 |
+
# https://github.com/facebookresearch/deit/
|
6 |
+
# https://github.com/facebookresearch/dino
|
7 |
+
# --------------------------------------------------------'
|
8 |
+
# not tested yet
|
9 |
+
import math
|
10 |
+
from transformers import CLIPImageProcessor
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import torch.utils.checkpoint as checkpoint
|
16 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
17 |
+
from .eva_clip import create_model_and_transforms, get_model_config
|
18 |
+
import torch
|
19 |
+
import torchvision
|
20 |
+
import time
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
class EvaViTWrapper(nn.Module):
|
25 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
26 |
+
super().__init__()
|
27 |
+
|
28 |
+
self.is_loaded = False
|
29 |
+
self.vision_tower_name = vision_tower
|
30 |
+
self.pretrained = args.vision_tower_pretrained
|
31 |
+
self.args = args
|
32 |
+
|
33 |
+
self.select_layer = args.mm_vision_select_layer
|
34 |
+
if self.select_layer < -1:
|
35 |
+
self.select_layer += 1
|
36 |
+
self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
|
37 |
+
|
38 |
+
self.model_config = get_model_config(self.vision_tower_name)
|
39 |
+
|
40 |
+
if not delay_load:
|
41 |
+
print(f"Loading vision tower: {vision_tower}")
|
42 |
+
self.load_model()
|
43 |
+
elif getattr(args, "unfreeze_mm_vision_tower", False):
|
44 |
+
# TODO: better detector is needed.
|
45 |
+
print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
|
46 |
+
self.load_model()
|
47 |
+
elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
|
48 |
+
print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
|
49 |
+
self.load_model()
|
50 |
+
|
51 |
+
def load_model(self):
|
52 |
+
print(f"Loading: {self.vision_tower_name}")
|
53 |
+
print(f"Pretrained: {self.pretrained}")
|
54 |
+
time_start = time.time()
|
55 |
+
model, _, image_processor = create_model_and_transforms(self.vision_tower_name, self.pretrained, force_custom_clip=True, precision="fp16")
|
56 |
+
time_end = time.time()
|
57 |
+
print(f"Loaded: {self.vision_tower_name} in {time_end - time_start:.2f}s")
|
58 |
+
self.device = next(model.parameters()).device
|
59 |
+
self.dtype = next(model.parameters()).dtype
|
60 |
+
if self.device.type != "meta":
|
61 |
+
model = model.to("cuda")
|
62 |
+
self.vision_tower = model.visual
|
63 |
+
resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0]
|
64 |
+
normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0]
|
65 |
+
self.resize_transform_size = resize_transform.size
|
66 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(
|
67 |
+
"openai/clip-vit-large-patch14",
|
68 |
+
crop_size=resize_transform.size,
|
69 |
+
size={"shortest_edge": resize_transform.size},
|
70 |
+
image_mean=list(normalize_transform.mean),
|
71 |
+
image_std=list(normalize_transform.std),
|
72 |
+
)
|
73 |
+
print(f"Loaded image processor: {self.image_processor}")
|
74 |
+
self.vision_tower.requires_grad_(False)
|
75 |
+
self.is_loaded = True
|
76 |
+
|
77 |
+
def feature_select(self, image_features):
|
78 |
+
select_feature_type = self.select_feature
|
79 |
+
|
80 |
+
# if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
|
81 |
+
# select_every_k_layer = len(image_features) // 4
|
82 |
+
# image_features = torch.cat([image_features[i] for i in range(select_every_k_layer + self.select_layer, len(image_features), select_every_k_layer)], dim=-1)
|
83 |
+
# select_feature_type = select_feature_type.replace("slicefour_", "")
|
84 |
+
# elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
|
85 |
+
# select_layers = [-1, -4, -7, -10, 6]
|
86 |
+
# image_features = torch.cat([image_features[i] for i in select_layers], dim=-1)
|
87 |
+
# select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
|
88 |
+
# else:
|
89 |
+
# image_features = image_features[self.select_layer]
|
90 |
+
|
91 |
+
if select_feature_type == "patch":
|
92 |
+
image_features = image_features[:, 1:]
|
93 |
+
elif select_feature_type == "cls_patch":
|
94 |
+
image_features = image_features
|
95 |
+
else:
|
96 |
+
raise ValueError(f"Unexpected select feature: {select_feature_type}")
|
97 |
+
return image_features
|
98 |
+
|
99 |
+
def train(self, mode=True):
|
100 |
+
self.training = mode
|
101 |
+
|
102 |
+
if self.is_loaded:
|
103 |
+
self.vision_tower.eval()
|
104 |
+
|
105 |
+
def forward(self, images):
|
106 |
+
if type(images) is list:
|
107 |
+
image_features = []
|
108 |
+
for image in images:
|
109 |
+
image_features = self.vision_tower.forward_features(image.to(self.dtype), return_all_features=True)
|
110 |
+
image_features = self.feature_select(image_features).to(self.dtype)
|
111 |
+
image_features.append(image_features)
|
112 |
+
else:
|
113 |
+
image_features = self.vision_tower.forward_features(images.to(self.dtype), return_all_features=True)
|
114 |
+
image_features = self.feature_select(image_features).to(self.dtype)
|
115 |
+
|
116 |
+
return image_features
|
117 |
+
|
118 |
+
@property
|
119 |
+
def dummy_feature(self):
|
120 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
121 |
+
|
122 |
+
@property
|
123 |
+
def hidden_size(self):
|
124 |
+
return self.model_config["vision_cfg"]["width"]
|
125 |
+
|
126 |
+
@property
|
127 |
+
def num_patches(self):
|
128 |
+
return (self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]) ** 2
|
129 |
+
|
130 |
+
@property
|
131 |
+
def num_patches_per_side(self):
|
132 |
+
return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]
|
133 |
+
|
134 |
+
@property
|
135 |
+
def config(self):
|
136 |
+
return self.model_config
|
137 |
+
|
138 |
+
@property
|
139 |
+
def image_size(self):
|
140 |
+
return self.model_config["vision_cfg"]["image_size"]
|
blip3o/model/multimodal_encoder/eva_clip/eva_clip_encoder.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .eva_clip_processors import EvaClipImageTrainProcessor
|
5 |
+
from .eva_vit import EVAEncoderWrapper
|
6 |
+
from .factory import list_models, add_model_config, get_model_config
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
class EvaClipVisionTower(nn.Module):
|
11 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
12 |
+
super().__init__()
|
13 |
+
|
14 |
+
self.is_loaded = False
|
15 |
+
self.vision_tower_name = vision_tower
|
16 |
+
self.vision_tower_pretrained = args.vision_tower_pretrained
|
17 |
+
self.config = get_model_config(vision_tower)
|
18 |
+
|
19 |
+
|
20 |
+
if not delay_load:
|
21 |
+
print(f"Loading EVA ViT: {self.vision_tower_name}")
|
22 |
+
self.load_model()
|
23 |
+
elif getattr(args, "unfreeze_mm_vision_tower", False):
|
24 |
+
# TODO: better detector is needed.
|
25 |
+
print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
|
26 |
+
self.load_model()
|
27 |
+
elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
|
28 |
+
print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
|
29 |
+
self.load_model()
|
30 |
+
else:
|
31 |
+
self.cfg_only = self.config
|
32 |
+
|
33 |
+
|
34 |
+
def load_model(self, device_map=None):
|
35 |
+
print(f"Pretrained: {self.vision_tower_pretrained}")
|
36 |
+
self.image_processor = EvaClipImageTrainProcessor(self.config["vision_cfg"]["image_size"])
|
37 |
+
self.vision_tower = EVAEncoderWrapper(self.vision_tower_pretrained, self.config)
|
38 |
+
print(f"Loaded image processor: {self.image_processor}")
|
39 |
+
self.vision_tower.requires_grad_(False)
|
40 |
+
self.is_loaded = True
|
41 |
+
|
42 |
+
def forward(self, images):
|
43 |
+
if type(images) is list:
|
44 |
+
image_features = []
|
45 |
+
for image in images:
|
46 |
+
image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)).to(image.dtype)
|
47 |
+
image_features.append(image_feature)
|
48 |
+
else:
|
49 |
+
image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(images.dtype)
|
50 |
+
|
51 |
+
return image_features
|
52 |
+
|
53 |
+
@property
|
54 |
+
def dtype(self):
|
55 |
+
return self.vision_tower.dtype
|
56 |
+
|
57 |
+
@property
|
58 |
+
def device(self):
|
59 |
+
return self.vision_tower.device
|
60 |
+
|
61 |
+
@property
|
62 |
+
def hidden_size(self):
|
63 |
+
return self.config["vision_cfg"]["width"]
|
64 |
+
|
65 |
+
@property
|
66 |
+
def num_patches(self):
|
67 |
+
return (self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"]) ** 2
|
68 |
+
|
69 |
+
@property
|
70 |
+
def num_patches_per_side(self):
|
71 |
+
return self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"]
|
72 |
+
|
73 |
+
@property
|
74 |
+
def image_size(self):
|
75 |
+
return self.config["vision_cfg"]["image_size"]
|
blip3o/model/multimodal_encoder/eva_clip/eva_clip_processors.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
# Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
|
3 |
+
"""
|
4 |
+
|
5 |
+
from torchvision import transforms
|
6 |
+
from torchvision.transforms.functional import InterpolationMode
|
7 |
+
from transformers.image_processing_utils import BatchFeature
|
8 |
+
from PIL import Image
|
9 |
+
from transformers.image_transforms import convert_to_rgb
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
class BaseProcessor:
|
13 |
+
def __init__(self):
|
14 |
+
self.transform = lambda x: x
|
15 |
+
return
|
16 |
+
|
17 |
+
def __call__(self, item):
|
18 |
+
return self.transform(item)
|
19 |
+
|
20 |
+
|
21 |
+
class EvaClipImageBaseProcessor(BaseProcessor):
|
22 |
+
def __init__(self, mean=None, std=None):
|
23 |
+
self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
|
24 |
+
self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
|
25 |
+
self.normalize = transforms.Normalize(self.mean, self.std)
|
26 |
+
|
27 |
+
|
28 |
+
@property
|
29 |
+
def image_mean(self):
|
30 |
+
return self.mean
|
31 |
+
|
32 |
+
|
33 |
+
class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor):
|
34 |
+
def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
|
35 |
+
super().__init__(mean=mean, std=std)
|
36 |
+
|
37 |
+
self.transform = transforms.Compose(
|
38 |
+
[
|
39 |
+
convert_to_rgb,
|
40 |
+
transforms.Resize(
|
41 |
+
image_size,
|
42 |
+
interpolation=InterpolationMode.BICUBIC,
|
43 |
+
),
|
44 |
+
transforms.CenterCrop(image_size),
|
45 |
+
transforms.ToTensor(),
|
46 |
+
self.normalize,
|
47 |
+
]
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
self.image_size = image_size
|
52 |
+
|
53 |
+
def preprocess(self, images, return_tensors):
|
54 |
+
if isinstance(images, Image.Image):
|
55 |
+
images = [images]
|
56 |
+
else:
|
57 |
+
assert isinstance(images, list)
|
58 |
+
|
59 |
+
transformed_images = [self.transform(image).numpy() for image in images]
|
60 |
+
data = {"pixel_values": transformed_images}
|
61 |
+
|
62 |
+
|
63 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
64 |
+
|
65 |
+
def __call__(self, item):
|
66 |
+
return self.transform(item)
|
67 |
+
|
68 |
+
@property
|
69 |
+
def crop_size(self):
|
70 |
+
return {"height": self.image_size, "width": self.image_size}
|
71 |
+
|
72 |
+
@property
|
73 |
+
def size(self):
|
74 |
+
return {"shortest_edge": self.image_size}
|
blip3o/model/multimodal_encoder/eva_clip/eva_vit.py
ADDED
@@ -0,0 +1,762 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
# Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
|
3 |
+
"""
|
4 |
+
|
5 |
+
from math import pi
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from einops import rearrange, repeat
|
9 |
+
import logging
|
10 |
+
|
11 |
+
from huggingface_hub import snapshot_download
|
12 |
+
cache_dir = snapshot_download(repo_id="jiuhai/eva_clip_vision_tower")
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
def broadcat(tensors, dim=-1):
|
18 |
+
num_tensors = len(tensors)
|
19 |
+
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
20 |
+
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
21 |
+
shape_len = list(shape_lens)[0]
|
22 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
23 |
+
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
24 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
25 |
+
assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), "invalid dimensions for broadcastable concatentation"
|
26 |
+
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
27 |
+
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
28 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
29 |
+
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
30 |
+
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
31 |
+
return torch.cat(tensors, dim=dim)
|
32 |
+
|
33 |
+
|
34 |
+
def rotate_half(x):
|
35 |
+
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
36 |
+
x1, x2 = x.unbind(dim=-1)
|
37 |
+
x = torch.stack((-x2, x1), dim=-1)
|
38 |
+
return rearrange(x, "... d r -> ... (d r)")
|
39 |
+
|
40 |
+
|
41 |
+
class VisionRotaryEmbeddingFast(nn.Module):
|
42 |
+
def __init__(self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, patch_dropout=0.0):
|
43 |
+
super().__init__()
|
44 |
+
if custom_freqs:
|
45 |
+
freqs = custom_freqs
|
46 |
+
elif freqs_for == "lang":
|
47 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
48 |
+
elif freqs_for == "pixel":
|
49 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
50 |
+
elif freqs_for == "constant":
|
51 |
+
freqs = torch.ones(num_freqs).float()
|
52 |
+
else:
|
53 |
+
raise ValueError(f"unknown modality {freqs_for}")
|
54 |
+
|
55 |
+
if ft_seq_len is None:
|
56 |
+
ft_seq_len = pt_seq_len
|
57 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
58 |
+
|
59 |
+
freqs = torch.einsum("..., f -> ... f", t, freqs)
|
60 |
+
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
|
61 |
+
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
|
62 |
+
|
63 |
+
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
|
64 |
+
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
|
65 |
+
|
66 |
+
self.patch_dropout = patch_dropout
|
67 |
+
|
68 |
+
self.register_buffer("freqs_cos", freqs_cos)
|
69 |
+
self.register_buffer("freqs_sin", freqs_sin)
|
70 |
+
|
71 |
+
logging.info(f"Shape of rope freq: {self.freqs_cos.shape}")
|
72 |
+
|
73 |
+
def forward(self, t, patch_indices_keep=None):
|
74 |
+
if patch_indices_keep is not None:
|
75 |
+
batch = t.size()[0]
|
76 |
+
batch_indices = torch.arange(batch)
|
77 |
+
batch_indices = batch_indices[..., None]
|
78 |
+
|
79 |
+
freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
|
80 |
+
freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1])
|
81 |
+
|
82 |
+
freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
|
83 |
+
freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j")
|
84 |
+
freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
|
85 |
+
freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j")
|
86 |
+
|
87 |
+
return t * freqs_cos + rotate_half(t) * freqs_sin
|
88 |
+
|
89 |
+
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
|
90 |
+
|
91 |
+
|
92 |
+
class LayerNorm(nn.LayerNorm):
|
93 |
+
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
|
94 |
+
|
95 |
+
def forward(self, x: torch.Tensor):
|
96 |
+
orig_type = x.dtype
|
97 |
+
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
98 |
+
return x.to(orig_type)
|
99 |
+
|
100 |
+
|
101 |
+
class PatchDropout(nn.Module):
|
102 |
+
"""
|
103 |
+
https://arxiv.org/abs/2212.00794
|
104 |
+
"""
|
105 |
+
|
106 |
+
def __init__(self, prob, exclude_first_token=True):
|
107 |
+
super().__init__()
|
108 |
+
assert 0 <= prob < 1.
|
109 |
+
self.prob = prob
|
110 |
+
self.exclude_first_token = exclude_first_token # exclude CLS token
|
111 |
+
print(f"os.getenv('RoPE')={os.getenv('RoPE')}")
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
if not self.training or self.prob == 0.:
|
115 |
+
return x
|
116 |
+
|
117 |
+
if self.exclude_first_token:
|
118 |
+
cls_tokens, x = x[:, :1], x[:, 1:]
|
119 |
+
else:
|
120 |
+
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
121 |
+
|
122 |
+
batch = x.size()[0]
|
123 |
+
num_tokens = x.size()[1]
|
124 |
+
|
125 |
+
batch_indices = torch.arange(batch)
|
126 |
+
batch_indices = batch_indices[..., None]
|
127 |
+
|
128 |
+
keep_prob = 1 - self.prob
|
129 |
+
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
130 |
+
|
131 |
+
rand = torch.randn(batch, num_tokens)
|
132 |
+
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
133 |
+
|
134 |
+
x = x[batch_indices, patch_indices_keep]
|
135 |
+
|
136 |
+
if self.exclude_first_token:
|
137 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
138 |
+
|
139 |
+
if self.training and os.getenv('RoPE') == '1':
|
140 |
+
return x, patch_indices_keep
|
141 |
+
|
142 |
+
return x
|
143 |
+
|
144 |
+
|
145 |
+
# --------------------------------------------------------
|
146 |
+
# Adapted from https://github.com/microsoft/unilm/tree/master/beit
|
147 |
+
# --------------------------------------------------------
|
148 |
+
import math
|
149 |
+
import os
|
150 |
+
import torch.nn as nn
|
151 |
+
import torch.nn.functional as F
|
152 |
+
|
153 |
+
try:
|
154 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
155 |
+
except:
|
156 |
+
from timm.layers import drop_path, to_2tuple, trunc_normal_
|
157 |
+
|
158 |
+
if os.getenv("ENV_TYPE") == "deepspeed":
|
159 |
+
try:
|
160 |
+
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
|
161 |
+
except:
|
162 |
+
from torch.utils.checkpoint import checkpoint
|
163 |
+
else:
|
164 |
+
from torch.utils.checkpoint import checkpoint
|
165 |
+
|
166 |
+
try:
|
167 |
+
import xformers.ops as xops
|
168 |
+
except ImportError:
|
169 |
+
xops = None
|
170 |
+
# print("Please 'pip install xformers'")
|
171 |
+
|
172 |
+
|
173 |
+
class DropPath(nn.Module):
|
174 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
175 |
+
"""
|
176 |
+
def __init__(self, drop_prob=None):
|
177 |
+
super(DropPath, self).__init__()
|
178 |
+
self.drop_prob = drop_prob
|
179 |
+
|
180 |
+
def forward(self, x):
|
181 |
+
return drop_path(x, self.drop_prob, self.training)
|
182 |
+
|
183 |
+
def extra_repr(self) -> str:
|
184 |
+
return 'p={}'.format(self.drop_prob)
|
185 |
+
|
186 |
+
|
187 |
+
class Mlp(nn.Module):
|
188 |
+
def __init__(
|
189 |
+
self,
|
190 |
+
in_features,
|
191 |
+
hidden_features=None,
|
192 |
+
out_features=None,
|
193 |
+
act_layer=nn.GELU,
|
194 |
+
norm_layer=nn.LayerNorm,
|
195 |
+
drop=0.,
|
196 |
+
subln=False,
|
197 |
+
|
198 |
+
):
|
199 |
+
super().__init__()
|
200 |
+
out_features = out_features or in_features
|
201 |
+
hidden_features = hidden_features or in_features
|
202 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
203 |
+
self.act = act_layer()
|
204 |
+
|
205 |
+
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
|
206 |
+
|
207 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
208 |
+
self.drop = nn.Dropout(drop)
|
209 |
+
|
210 |
+
def forward(self, x):
|
211 |
+
x = self.fc1(x)
|
212 |
+
x = self.act(x)
|
213 |
+
# x = self.drop(x)
|
214 |
+
# commit this for the orignal BERT implement
|
215 |
+
x = self.ffn_ln(x)
|
216 |
+
|
217 |
+
x = self.fc2(x)
|
218 |
+
x = self.drop(x)
|
219 |
+
return x
|
220 |
+
|
221 |
+
|
222 |
+
class SwiGLU(nn.Module):
|
223 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
|
224 |
+
norm_layer=nn.LayerNorm, subln=False):
|
225 |
+
super().__init__()
|
226 |
+
out_features = out_features or in_features
|
227 |
+
hidden_features = hidden_features or in_features
|
228 |
+
|
229 |
+
self.w1 = nn.Linear(in_features, hidden_features)
|
230 |
+
self.w2 = nn.Linear(in_features, hidden_features)
|
231 |
+
|
232 |
+
self.act = act_layer()
|
233 |
+
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
|
234 |
+
self.w3 = nn.Linear(hidden_features, out_features)
|
235 |
+
|
236 |
+
self.drop = nn.Dropout(drop)
|
237 |
+
|
238 |
+
def forward(self, x):
|
239 |
+
x1 = self.w1(x)
|
240 |
+
x2 = self.w2(x)
|
241 |
+
hidden = self.act(x1) * x2
|
242 |
+
x = self.ffn_ln(hidden)
|
243 |
+
x = self.w3(x)
|
244 |
+
x = self.drop(x)
|
245 |
+
return x
|
246 |
+
|
247 |
+
|
248 |
+
class Attention(nn.Module):
|
249 |
+
def __init__(
|
250 |
+
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
251 |
+
proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
|
252 |
+
super().__init__()
|
253 |
+
self.num_heads = num_heads
|
254 |
+
head_dim = dim // num_heads
|
255 |
+
if attn_head_dim is not None:
|
256 |
+
head_dim = attn_head_dim
|
257 |
+
all_head_dim = head_dim * self.num_heads
|
258 |
+
self.scale = qk_scale or head_dim ** -0.5
|
259 |
+
|
260 |
+
self.subln = subln
|
261 |
+
if self.subln:
|
262 |
+
self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
|
263 |
+
self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
|
264 |
+
self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
|
265 |
+
else:
|
266 |
+
if qkv_bias:
|
267 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=True)
|
268 |
+
else:
|
269 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
270 |
+
|
271 |
+
# if qkv_bias:
|
272 |
+
# self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
273 |
+
# self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
274 |
+
# else:
|
275 |
+
# self.q_bias = None
|
276 |
+
# self.v_bias = None
|
277 |
+
|
278 |
+
self.window_size = None
|
279 |
+
self.relative_position_bias_table = None
|
280 |
+
self.relative_position_index = None
|
281 |
+
|
282 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
283 |
+
self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
|
284 |
+
# self.proj = nn.Linear(all_head_dim, all_head_dim)
|
285 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
286 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
287 |
+
self.xattn = xattn
|
288 |
+
self.xattn_drop = attn_drop
|
289 |
+
|
290 |
+
self.rope = rope
|
291 |
+
|
292 |
+
def forward(self, x, rel_pos_bias=None, attn_mask=None):
|
293 |
+
B, N, C = x.shape
|
294 |
+
if self.subln:
|
295 |
+
q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
|
296 |
+
k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
|
297 |
+
v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
|
298 |
+
|
299 |
+
q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
|
300 |
+
k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
301 |
+
v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
302 |
+
else:
|
303 |
+
|
304 |
+
# qkv_bias = None
|
305 |
+
# if self.q_bias is not None:
|
306 |
+
# qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
307 |
+
|
308 |
+
# qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
309 |
+
|
310 |
+
qkv = self.qkv(x)
|
311 |
+
|
312 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
|
313 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
314 |
+
|
315 |
+
if self.rope:
|
316 |
+
q_t = q[:, :, 1:, :]
|
317 |
+
ro_q_t = self.rope(q_t)
|
318 |
+
q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
|
319 |
+
|
320 |
+
k_t = k[:, :, 1:, :]
|
321 |
+
ro_k_t = self.rope(k_t)
|
322 |
+
k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
|
323 |
+
|
324 |
+
if self.xattn:
|
325 |
+
q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
|
326 |
+
k = k.permute(0, 2, 1, 3)
|
327 |
+
v = v.permute(0, 2, 1, 3)
|
328 |
+
|
329 |
+
x = xops.memory_efficient_attention(
|
330 |
+
q, k, v,
|
331 |
+
p=self.xattn_drop,
|
332 |
+
scale=self.scale,
|
333 |
+
)
|
334 |
+
x = x.reshape(B, N, -1)
|
335 |
+
x = self.inner_attn_ln(x)
|
336 |
+
x = self.proj(x)
|
337 |
+
x = self.proj_drop(x)
|
338 |
+
else:
|
339 |
+
q = q * self.scale
|
340 |
+
attn = (q @ k.transpose(-2, -1))
|
341 |
+
|
342 |
+
if self.relative_position_bias_table is not None:
|
343 |
+
relative_position_bias = \
|
344 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
345 |
+
self.window_size[0] * self.window_size[1] + 1,
|
346 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
347 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
348 |
+
attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
|
349 |
+
|
350 |
+
if rel_pos_bias is not None:
|
351 |
+
attn = attn + rel_pos_bias.type_as(attn)
|
352 |
+
|
353 |
+
if attn_mask is not None:
|
354 |
+
attn_mask = attn_mask.bool()
|
355 |
+
attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
|
356 |
+
|
357 |
+
attn = attn.softmax(dim=-1)
|
358 |
+
attn = self.attn_drop(attn)
|
359 |
+
|
360 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
361 |
+
x = self.inner_attn_ln(x)
|
362 |
+
x = self.proj(x)
|
363 |
+
x = self.proj_drop(x)
|
364 |
+
return x
|
365 |
+
|
366 |
+
|
367 |
+
class Block(nn.Module):
|
368 |
+
|
369 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
370 |
+
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
371 |
+
window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
|
372 |
+
subln=False, naiveswiglu=False):
|
373 |
+
super().__init__()
|
374 |
+
self.norm1 = norm_layer(dim)
|
375 |
+
self.attn = Attention(
|
376 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
377 |
+
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
|
378 |
+
xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
|
379 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
380 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
381 |
+
self.norm2 = norm_layer(dim)
|
382 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
383 |
+
|
384 |
+
if naiveswiglu:
|
385 |
+
self.mlp = SwiGLU(
|
386 |
+
in_features=dim,
|
387 |
+
hidden_features=mlp_hidden_dim,
|
388 |
+
subln=subln,
|
389 |
+
norm_layer=norm_layer,
|
390 |
+
)
|
391 |
+
else:
|
392 |
+
self.mlp = Mlp(
|
393 |
+
in_features=dim,
|
394 |
+
hidden_features=mlp_hidden_dim,
|
395 |
+
act_layer=act_layer,
|
396 |
+
subln=subln,
|
397 |
+
drop=drop
|
398 |
+
)
|
399 |
+
|
400 |
+
if init_values is not None and init_values > 0:
|
401 |
+
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
402 |
+
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
403 |
+
else:
|
404 |
+
self.gamma_1, self.gamma_2 = None, None
|
405 |
+
|
406 |
+
self.postnorm = postnorm
|
407 |
+
|
408 |
+
def forward(self, x, rel_pos_bias=None, attn_mask=None):
|
409 |
+
if self.gamma_1 is None:
|
410 |
+
if self.postnorm:
|
411 |
+
x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
|
412 |
+
x = x + self.drop_path(self.norm2(self.mlp(x)))
|
413 |
+
else:
|
414 |
+
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
|
415 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
416 |
+
else:
|
417 |
+
if self.postnorm:
|
418 |
+
x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
|
419 |
+
x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
|
420 |
+
else:
|
421 |
+
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
|
422 |
+
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
423 |
+
return x
|
424 |
+
|
425 |
+
|
426 |
+
class PatchEmbed(nn.Module):
|
427 |
+
""" Image to Patch Embedding
|
428 |
+
"""
|
429 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
430 |
+
super().__init__()
|
431 |
+
img_size = to_2tuple(img_size)
|
432 |
+
patch_size = to_2tuple(patch_size)
|
433 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
434 |
+
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
435 |
+
self.img_size = img_size
|
436 |
+
self.patch_size = patch_size
|
437 |
+
self.num_patches = num_patches
|
438 |
+
|
439 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
440 |
+
|
441 |
+
def forward(self, x, **kwargs):
|
442 |
+
B, C, H, W = x.shape
|
443 |
+
# FIXME look at relaxing size constraints
|
444 |
+
assert H == self.img_size[0] and W == self.img_size[1], \
|
445 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
446 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
447 |
+
return x
|
448 |
+
|
449 |
+
|
450 |
+
class RelativePositionBias(nn.Module):
|
451 |
+
|
452 |
+
def __init__(self, window_size, num_heads):
|
453 |
+
super().__init__()
|
454 |
+
self.window_size = window_size
|
455 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
456 |
+
self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
457 |
+
# cls to token & token 2 cls & cls to cls
|
458 |
+
|
459 |
+
# get pair-wise relative position index for each token inside the window
|
460 |
+
coords_h = torch.arange(window_size[0])
|
461 |
+
coords_w = torch.arange(window_size[1])
|
462 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
463 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
464 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
465 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
466 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
467 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
468 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
469 |
+
relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
470 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
471 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
472 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
473 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
474 |
+
|
475 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
476 |
+
|
477 |
+
def forward(self):
|
478 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
479 |
+
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
480 |
+
|
481 |
+
|
482 |
+
class EVAVisionTransformer(nn.Module):
|
483 |
+
"""Vision Transformer with support for patch or hybrid CNN input stage"""
|
484 |
+
|
485 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
486 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
487 |
+
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
|
488 |
+
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
|
489 |
+
use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
|
490 |
+
pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False,
|
491 |
+
):
|
492 |
+
super().__init__()
|
493 |
+
self.image_size = img_size
|
494 |
+
# self.num_classes = num_classes
|
495 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
496 |
+
|
497 |
+
self.patch_embed = PatchEmbed(
|
498 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
499 |
+
num_patches = self.patch_embed.num_patches
|
500 |
+
|
501 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
502 |
+
if use_abs_pos_emb:
|
503 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
504 |
+
else:
|
505 |
+
self.pos_embed = None
|
506 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
507 |
+
|
508 |
+
self.rel_pos_bias = None
|
509 |
+
self.rope = None
|
510 |
+
|
511 |
+
self.naiveswiglu = naiveswiglu
|
512 |
+
|
513 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
514 |
+
self.use_rel_pos_bias = use_rel_pos_bias
|
515 |
+
self.blocks = nn.ModuleList([
|
516 |
+
Block(
|
517 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
518 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
519 |
+
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
|
520 |
+
xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
|
521 |
+
for i in range(depth)])
|
522 |
+
|
523 |
+
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
|
524 |
+
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
|
525 |
+
|
526 |
+
self.grad_checkpointing = grad_checkpointing
|
527 |
+
|
528 |
+
def fix_init_weight(self):
|
529 |
+
def rescale(param, layer_id):
|
530 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
531 |
+
|
532 |
+
for layer_id, layer in enumerate(self.blocks):
|
533 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
534 |
+
if self.naiveswiglu:
|
535 |
+
rescale(layer.mlp.w3.weight.data, layer_id + 1)
|
536 |
+
else:
|
537 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
538 |
+
|
539 |
+
def get_cast_dtype(self) -> torch.dtype:
|
540 |
+
return self.blocks[0].mlp.fc2.weight.dtype
|
541 |
+
|
542 |
+
def _init_weights(self, m):
|
543 |
+
if isinstance(m, nn.Linear):
|
544 |
+
trunc_normal_(m.weight, std=0.02)
|
545 |
+
if m.bias is not None:
|
546 |
+
nn.init.constant_(m.bias, 0)
|
547 |
+
elif isinstance(m, nn.LayerNorm):
|
548 |
+
nn.init.constant_(m.bias, 0)
|
549 |
+
nn.init.constant_(m.weight, 1.0)
|
550 |
+
|
551 |
+
def get_num_layers(self):
|
552 |
+
return len(self.blocks)
|
553 |
+
|
554 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
555 |
+
assert unlocked_groups == 0, "partial locking not currently supported for this model"
|
556 |
+
for param in self.parameters():
|
557 |
+
param.requires_grad = False
|
558 |
+
|
559 |
+
@torch.jit.ignore
|
560 |
+
def set_grad_checkpointing(self, enable=True):
|
561 |
+
self.grad_checkpointing = enable
|
562 |
+
|
563 |
+
@torch.jit.ignore
|
564 |
+
def no_weight_decay(self):
|
565 |
+
return {"pos_embed", "cls_token"}
|
566 |
+
|
567 |
+
def get_classifier(self):
|
568 |
+
return self.head
|
569 |
+
|
570 |
+
def reset_classifier(self, num_classes, global_pool=""):
|
571 |
+
self.num_classes = num_classes
|
572 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
573 |
+
|
574 |
+
def forward_features(self, x):
|
575 |
+
x = self.patch_embed(x)
|
576 |
+
batch_size, seq_len, _ = x.size()
|
577 |
+
|
578 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
579 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
580 |
+
if self.pos_embed is not None:
|
581 |
+
x = x + self.pos_embed
|
582 |
+
x = self.pos_drop(x)
|
583 |
+
|
584 |
+
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
|
585 |
+
if os.getenv('RoPE') == '1':
|
586 |
+
if self.training and not isinstance(self.patch_dropout, nn.Identity):
|
587 |
+
x, patch_indices_keep = self.patch_dropout(x)
|
588 |
+
self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
|
589 |
+
else:
|
590 |
+
self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
|
591 |
+
x = self.patch_dropout(x)
|
592 |
+
else:
|
593 |
+
x = self.patch_dropout(x)
|
594 |
+
|
595 |
+
rel_pos_bias = None
|
596 |
+
|
597 |
+
for blk in self.blocks:
|
598 |
+
if self.grad_checkpointing:
|
599 |
+
x = checkpoint(blk, x, (rel_pos_bias,))
|
600 |
+
else:
|
601 |
+
x = blk(x, rel_pos_bias=rel_pos_bias)
|
602 |
+
|
603 |
+
return x
|
604 |
+
|
605 |
+
def forward(self, x, return_all_features=False):
|
606 |
+
|
607 |
+
"""
|
608 |
+
:return:
|
609 |
+
forward_features function returns raw features of ViT,
|
610 |
+
forward with return_all_features returns normalized features of ViT
|
611 |
+
:param x:
|
612 |
+
:param return_all_features:
|
613 |
+
"""
|
614 |
+
|
615 |
+
features = self.forward_features(x) # [B, n_patch, C]
|
616 |
+
return features
|
617 |
+
|
618 |
+
|
619 |
+
def load_state_dict(checkpoint_path: str, map_location: str = "cpu", model_key: str = "model|module|state_dict", is_openai: bool = False, skip_list: list = []):
|
620 |
+
if is_openai:
|
621 |
+
model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
|
622 |
+
state_dict = model.state_dict()
|
623 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
624 |
+
state_dict.pop(key, None)
|
625 |
+
else:
|
626 |
+
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
627 |
+
for mk in model_key.split("|"):
|
628 |
+
if isinstance(checkpoint, dict) and mk in checkpoint:
|
629 |
+
state_dict = checkpoint[mk]
|
630 |
+
break
|
631 |
+
else:
|
632 |
+
state_dict = checkpoint
|
633 |
+
if next(iter(state_dict.items()))[0].startswith("module"):
|
634 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
635 |
+
|
636 |
+
for k in skip_list:
|
637 |
+
if k in list(state_dict.keys()):
|
638 |
+
logging.info(f"Removing key {k} from pretrained checkpoint")
|
639 |
+
del state_dict[k]
|
640 |
+
|
641 |
+
if os.getenv("RoPE") == "1":
|
642 |
+
for k in list(state_dict.keys()):
|
643 |
+
if "freqs_cos" in k or "freqs_sin" in k:
|
644 |
+
del state_dict[k]
|
645 |
+
return state_dict
|
646 |
+
|
647 |
+
|
648 |
+
def load_clip_visual_state_dict(checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []):
|
649 |
+
state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
|
650 |
+
# for k in list(state_dict.keys()):
|
651 |
+
# if not k.startswith("visual."):
|
652 |
+
# del state_dict[k]
|
653 |
+
# for k in list(state_dict.keys()):
|
654 |
+
# if k.startswith("visual."):
|
655 |
+
# new_k = k[7:]
|
656 |
+
# state_dict[new_k] = state_dict[k]
|
657 |
+
# del state_dict[k]
|
658 |
+
return state_dict
|
659 |
+
|
660 |
+
|
661 |
+
from dataclasses import dataclass
|
662 |
+
from typing import Optional, Tuple, Union
|
663 |
+
|
664 |
+
try:
|
665 |
+
from apex.normalization import FusedLayerNorm
|
666 |
+
except:
|
667 |
+
FusedLayerNorm = LayerNorm
|
668 |
+
# print("Please build and install Nvidia apex package with option '--cuda_ext' according to https://github.com/NVIDIA/apex#from-source .")
|
669 |
+
|
670 |
+
|
671 |
+
@dataclass
|
672 |
+
class CLIPVisionCfg:
|
673 |
+
layers: Union[Tuple[int, int, int, int], int] = 12
|
674 |
+
width: int = 768
|
675 |
+
head_width: int = 64
|
676 |
+
mlp_ratio: float = 4.0
|
677 |
+
patch_size: int = 16
|
678 |
+
image_size: Union[Tuple[int, int], int] = 224
|
679 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
680 |
+
patch_dropout: float = 0.0 # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
|
681 |
+
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
|
682 |
+
drop_path_rate: Optional[float] = None # drop path rate
|
683 |
+
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
|
684 |
+
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
|
685 |
+
timm_pool: str = "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
|
686 |
+
timm_proj: str = "linear" # linear projection for timm model output ('linear', 'mlp', '')
|
687 |
+
timm_proj_bias: bool = False # enable bias final projection
|
688 |
+
eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
|
689 |
+
qkv_bias: bool = True
|
690 |
+
fusedLN: bool = False
|
691 |
+
xattn: bool = False
|
692 |
+
postnorm: bool = False
|
693 |
+
rope: bool = False
|
694 |
+
pt_hw_seq_len: int = 16 # 224/14
|
695 |
+
intp_freq: bool = False
|
696 |
+
naiveswiglu: bool = False
|
697 |
+
subln: bool = False
|
698 |
+
|
699 |
+
|
700 |
+
def create_norm_layer_factory(use_fused_ln, eps=1e-6):
|
701 |
+
# Otherwise, use the standard LayerNorm
|
702 |
+
return lambda num_features: nn.LayerNorm(num_features, eps=eps)
|
703 |
+
|
704 |
+
|
705 |
+
def _build_vision_tower(vision_tower_path: str, embed_dim: int, vision_cfg: CLIPVisionCfg, **kwargs):
|
706 |
+
if isinstance(vision_cfg, dict):
|
707 |
+
vision_cfg = CLIPVisionCfg(**vision_cfg)
|
708 |
+
|
709 |
+
if vision_cfg.eva_model_name:
|
710 |
+
vision_heads = vision_cfg.width // vision_cfg.head_width
|
711 |
+
# Determine the appropriate norm layer factory based on the configuration
|
712 |
+
norm_layer_factory = create_norm_layer_factory(vision_cfg.fusedLN, eps=1e-6)
|
713 |
+
|
714 |
+
# breakpoint()
|
715 |
+
visual = EVAVisionTransformer(
|
716 |
+
img_size=vision_cfg.image_size,
|
717 |
+
patch_size=vision_cfg.patch_size,
|
718 |
+
num_classes=embed_dim,
|
719 |
+
use_mean_pooling=vision_cfg.global_average_pool, # False
|
720 |
+
init_values=vision_cfg.ls_init_value,
|
721 |
+
patch_dropout=vision_cfg.patch_dropout,
|
722 |
+
embed_dim=vision_cfg.width,
|
723 |
+
depth=vision_cfg.layers,
|
724 |
+
num_heads=vision_heads,
|
725 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
726 |
+
qkv_bias=vision_cfg.qkv_bias,
|
727 |
+
drop_path_rate=vision_cfg.drop_path_rate,
|
728 |
+
norm_layer=norm_layer_factory,
|
729 |
+
xattn=vision_cfg.xattn,
|
730 |
+
rope=vision_cfg.rope,
|
731 |
+
postnorm=vision_cfg.postnorm,
|
732 |
+
pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14
|
733 |
+
intp_freq=vision_cfg.intp_freq,
|
734 |
+
naiveswiglu=vision_cfg.naiveswiglu,
|
735 |
+
subln=vision_cfg.subln,
|
736 |
+
)
|
737 |
+
# breakpoint()
|
738 |
+
state_dict = load_clip_visual_state_dict(vision_tower_path)
|
739 |
+
incompatible_keys = visual.load_state_dict(state_dict, strict=False)
|
740 |
+
print("EVA-CLIP incompatible_keys:", incompatible_keys)
|
741 |
+
|
742 |
+
return visual
|
743 |
+
|
744 |
+
|
745 |
+
class EVAEncoderWrapper(nn.Module):
|
746 |
+
def __init__(self, vision_tower_pretrained, config):
|
747 |
+
super(EVAEncoderWrapper, self).__init__()
|
748 |
+
self.config = config
|
749 |
+
self.config["vision_tower_path"] = os.path.join(cache_dir, "pytorch_model.bin")
|
750 |
+
self.model = _build_vision_tower(**self.config)
|
751 |
+
|
752 |
+
def forward(self, image, **kwargs):
|
753 |
+
encode = self.model(image, return_all_features=True)[:, 1:, :] # remove the CLS token
|
754 |
+
return encode
|
755 |
+
|
756 |
+
@property
|
757 |
+
def dtype(self):
|
758 |
+
return list(self.parameters())[-1].dtype
|
759 |
+
|
760 |
+
@property
|
761 |
+
def device(self):
|
762 |
+
return list(self.parameters())[-1].device
|
blip3o/model/multimodal_encoder/eva_clip/factory.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import pathlib
|
5 |
+
import re
|
6 |
+
from copy import deepcopy
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Optional, Tuple, Union, Dict, Any
|
9 |
+
import torch
|
10 |
+
|
11 |
+
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
|
12 |
+
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
|
13 |
+
|
14 |
+
|
15 |
+
def _natural_key(string_):
|
16 |
+
return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
|
17 |
+
|
18 |
+
|
19 |
+
def _rescan_model_configs():
|
20 |
+
global _MODEL_CONFIGS
|
21 |
+
|
22 |
+
config_ext = (".json",)
|
23 |
+
config_files = []
|
24 |
+
for config_path in _MODEL_CONFIG_PATHS:
|
25 |
+
if config_path.is_file() and config_path.suffix in config_ext:
|
26 |
+
config_files.append(config_path)
|
27 |
+
elif config_path.is_dir():
|
28 |
+
for ext in config_ext:
|
29 |
+
config_files.extend(config_path.glob(f"*{ext}"))
|
30 |
+
for cf in config_files:
|
31 |
+
with open(cf, "r", encoding="utf8") as f:
|
32 |
+
model_cfg = json.load(f)
|
33 |
+
if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")):
|
34 |
+
_MODEL_CONFIGS[cf.stem] = model_cfg
|
35 |
+
|
36 |
+
_MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
|
37 |
+
|
38 |
+
|
39 |
+
_rescan_model_configs() # initial populate of model config registry
|
40 |
+
|
41 |
+
|
42 |
+
def list_models():
|
43 |
+
"""enumerate available model architectures based on config files"""
|
44 |
+
return list(_MODEL_CONFIGS.keys())
|
45 |
+
|
46 |
+
|
47 |
+
def add_model_config(path):
|
48 |
+
"""add model config path or file and update registry"""
|
49 |
+
if not isinstance(path, Path):
|
50 |
+
path = Path(path)
|
51 |
+
_MODEL_CONFIG_PATHS.append(path)
|
52 |
+
_rescan_model_configs()
|
53 |
+
|
54 |
+
|
55 |
+
def get_model_config(model_name):
|
56 |
+
if model_name in _MODEL_CONFIGS:
|
57 |
+
return deepcopy(_MODEL_CONFIGS[model_name])
|
58 |
+
else:
|
59 |
+
return None
|
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-18B.json
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1536,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 48,
|
6 |
+
"width": 5120,
|
7 |
+
"head_width": 128,
|
8 |
+
"mlp_ratio": 5,
|
9 |
+
"patch_size": 14,
|
10 |
+
"eva_model_name": "eva-clip-18b-14-x",
|
11 |
+
"drop_path_rate": 0,
|
12 |
+
"qkv_bias": false,
|
13 |
+
"xattn": true,
|
14 |
+
"postnorm": true,
|
15 |
+
"fusedLN": false,
|
16 |
+
"use_rms_norm": true
|
17 |
+
},
|
18 |
+
"text_cfg": {
|
19 |
+
"context_length": 77,
|
20 |
+
"vocab_size": 49408,
|
21 |
+
"width": 1280,
|
22 |
+
"heads": 20,
|
23 |
+
"layers": 32,
|
24 |
+
"xattn": false,
|
25 |
+
"fusedLN": false
|
26 |
+
}
|
27 |
+
}
|
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B-plus.json
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1280,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 448,
|
5 |
+
"layers": 32,
|
6 |
+
"width": 4096,
|
7 |
+
"head_width": 128,
|
8 |
+
"mlp_ratio": 5,
|
9 |
+
"patch_size": 14,
|
10 |
+
"eva_model_name": "eva-clip-8b-14-plus-x",
|
11 |
+
"drop_path_rate": 0,
|
12 |
+
"qkv_bias": false,
|
13 |
+
"xattn": true,
|
14 |
+
"postnorm": false,
|
15 |
+
"fusedLN": false,
|
16 |
+
"use_rms_norm": true
|
17 |
+
},
|
18 |
+
"text_cfg": {
|
19 |
+
"context_length": 77,
|
20 |
+
"vocab_size": 49408,
|
21 |
+
"width": 1280,
|
22 |
+
"heads": 20,
|
23 |
+
"layers": 32,
|
24 |
+
"xattn": false,
|
25 |
+
"fusedLN": false
|
26 |
+
}
|
27 |
+
}
|
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B.json
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1280,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 32,
|
6 |
+
"width": 4096,
|
7 |
+
"head_width": 128,
|
8 |
+
"mlp_ratio": 5,
|
9 |
+
"patch_size": 14,
|
10 |
+
"eva_model_name": "eva-clip-8b-14-x",
|
11 |
+
"drop_path_rate": 0,
|
12 |
+
"qkv_bias": false,
|
13 |
+
"xattn": true,
|
14 |
+
"postnorm": false,
|
15 |
+
"fusedLN": false,
|
16 |
+
"use_rms_norm": true
|
17 |
+
},
|
18 |
+
"text_cfg": {
|
19 |
+
"context_length": 77,
|
20 |
+
"vocab_size": 49408,
|
21 |
+
"width": 1280,
|
22 |
+
"heads": 20,
|
23 |
+
"layers": 32,
|
24 |
+
"xattn": false,
|
25 |
+
"fusedLN": false
|
26 |
+
}
|
27 |
+
}
|
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-B-16.json
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 768,
|
7 |
+
"patch_size": 16,
|
8 |
+
"eva_model_name": "eva-clip-b-16",
|
9 |
+
"ls_init_value": 0.1,
|
10 |
+
"drop_path_rate": 0.0
|
11 |
+
},
|
12 |
+
"text_cfg": {
|
13 |
+
"context_length": 77,
|
14 |
+
"vocab_size": 49408,
|
15 |
+
"width": 512,
|
16 |
+
"heads": 8,
|
17 |
+
"layers": 12
|
18 |
+
}
|
19 |
+
}
|
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 40,
|
6 |
+
"width": 1408,
|
7 |
+
"head_width": 88,
|
8 |
+
"mlp_ratio": 4.3637,
|
9 |
+
"patch_size": 14,
|
10 |
+
"eva_model_name": "eva-clip-g-14-x",
|
11 |
+
"drop_path_rate": 0,
|
12 |
+
"xattn": true,
|
13 |
+
"fusedLN": true
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 1024,
|
19 |
+
"heads": 16,
|
20 |
+
"layers": 24,
|
21 |
+
"xattn": false,
|
22 |
+
"fusedLN": true
|
23 |
+
}
|
24 |
+
}
|
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 40,
|
6 |
+
"width": 1408,
|
7 |
+
"head_width": 88,
|
8 |
+
"mlp_ratio": 4.3637,
|
9 |
+
"patch_size": 14,
|
10 |
+
"eva_model_name": "eva-clip-g-14-x",
|
11 |
+
"drop_path_rate": 0.4,
|
12 |
+
"xattn": true,
|
13 |
+
"fusedLN": true
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 768,
|
19 |
+
"heads": 12,
|
20 |
+
"layers": 12,
|
21 |
+
"xattn": false,
|
22 |
+
"fusedLN": true
|
23 |
+
}
|
24 |
+
}
|
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-B-16.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 768,
|
7 |
+
"head_width": 64,
|
8 |
+
"patch_size": 16,
|
9 |
+
"mlp_ratio": 2.6667,
|
10 |
+
"eva_model_name": "eva-clip-b-16-X",
|
11 |
+
"drop_path_rate": 0.0,
|
12 |
+
"xattn": true,
|
13 |
+
"fusedLN": true,
|
14 |
+
"rope": true,
|
15 |
+
"pt_hw_seq_len": 16,
|
16 |
+
"intp_freq": true,
|
17 |
+
"naiveswiglu": true,
|
18 |
+
"subln": true
|
19 |
+
},
|
20 |
+
"text_cfg": {
|
21 |
+
"context_length": 77,
|
22 |
+
"vocab_size": 49408,
|
23 |
+
"width": 512,
|
24 |
+
"heads": 8,
|
25 |
+
"layers": 12,
|
26 |
+
"xattn": true,
|
27 |
+
"fusedLN": true
|
28 |
+
}
|
29 |
+
}
|
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14-336.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 336,
|
5 |
+
"layers": 24,
|
6 |
+
"width": 1024,
|
7 |
+
"drop_path_rate": 0,
|
8 |
+
"head_width": 64,
|
9 |
+
"mlp_ratio": 2.6667,
|
10 |
+
"patch_size": 14,
|
11 |
+
"eva_model_name": "eva-clip-l-14-336",
|
12 |
+
"xattn": true,
|
13 |
+
"fusedLN": true,
|
14 |
+
"rope": true,
|
15 |
+
"pt_hw_seq_len": 16,
|
16 |
+
"intp_freq": true,
|
17 |
+
"naiveswiglu": true,
|
18 |
+
"subln": true
|
19 |
+
},
|
20 |
+
"text_cfg": {
|
21 |
+
"context_length": 77,
|
22 |
+
"vocab_size": 49408,
|
23 |
+
"width": 768,
|
24 |
+
"heads": 12,
|
25 |
+
"layers": 12,
|
26 |
+
"xattn": false,
|
27 |
+
"fusedLN": true
|
28 |
+
}
|
29 |
+
}
|
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 24,
|
6 |
+
"width": 1024,
|
7 |
+
"drop_path_rate": 0,
|
8 |
+
"head_width": 64,
|
9 |
+
"mlp_ratio": 2.6667,
|
10 |
+
"patch_size": 14,
|
11 |
+
"eva_model_name": "eva-clip-l-14",
|
12 |
+
"xattn": true,
|
13 |
+
"fusedLN": true,
|
14 |
+
"rope": true,
|
15 |
+
"pt_hw_seq_len": 16,
|
16 |
+
"intp_freq": true,
|
17 |
+
"naiveswiglu": true,
|
18 |
+
"subln": true
|
19 |
+
},
|
20 |
+
"text_cfg": {
|
21 |
+
"context_length": 77,
|
22 |
+
"vocab_size": 49408,
|
23 |
+
"width": 768,
|
24 |
+
"heads": 12,
|
25 |
+
"layers": 12,
|
26 |
+
"xattn": false,
|
27 |
+
"fusedLN": true
|
28 |
+
}
|
29 |
+
}
|
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 64,
|
6 |
+
"width": 1792,
|
7 |
+
"head_width": 112,
|
8 |
+
"mlp_ratio": 8.571428571428571,
|
9 |
+
"patch_size": 14,
|
10 |
+
"eva_model_name": "eva-clip-4b-14-x",
|
11 |
+
"drop_path_rate": 0,
|
12 |
+
"xattn": true,
|
13 |
+
"postnorm": true,
|
14 |
+
"fusedLN": true
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 1280,
|
20 |
+
"heads": 20,
|
21 |
+
"layers": 32,
|
22 |
+
"xattn": false,
|
23 |
+
"fusedLN": true
|
24 |
+
}
|
25 |
+
}
|
26 |
+
|
27 |
+
|
28 |
+
|
blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14.json
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 64,
|
6 |
+
"width": 1792,
|
7 |
+
"head_width": 112,
|
8 |
+
"mlp_ratio": 8.571428571428571,
|
9 |
+
"patch_size": 14,
|
10 |
+
"eva_model_name": "eva-clip-4b-14-x",
|
11 |
+
"drop_path_rate": 0,
|
12 |
+
"xattn": true,
|
13 |
+
"postnorm": true,
|
14 |
+
"fusedLN": true
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 1024,
|
20 |
+
"heads": 16,
|
21 |
+
"layers": 24,
|
22 |
+
"xattn": false,
|
23 |
+
"fusedLN": true
|
24 |
+
}
|
25 |
+
}
|