diff --git a/blip3o/.DS_Store b/blip3o/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..d888fa4793ccbf9a93c7e701afab0501abd48f56 Binary files /dev/null and b/blip3o/.DS_Store differ diff --git a/blip3o/__init__.py b/blip3o/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..227940e931e3502b3f4b2ada4f01bf1246b53e7d --- /dev/null +++ b/blip3o/__init__.py @@ -0,0 +1 @@ +from .model import blip3oLlamaForCausalLM diff --git a/blip3o/constants.py b/blip3o/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..528af91aa7cf562e40ca377fd288c3e392b2c991 --- /dev/null +++ b/blip3o/constants.py @@ -0,0 +1,26 @@ +CONTROLLER_HEART_BEAT_EXPIRATION = 30 +WORKER_HEART_BEAT_INTERVAL = 15 + +LOGDIR = "." + +# Model Constants +IGNORE_INDEX = -100 +# IMAGE_TOKEN_INDEX = -200 + +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_IM_START_TOKEN = "[IMG]" +DEFAULT_IM_END_TOKEN = "[/IMG]" + + +# IMAGE_TOKEN_IDX = 32002 +# DEFAULT_IM_START_TOKEN_IDX = 32000 +# DEFAULT_IM_END_TOKEN_IDX = 32001 + +IMAGE_TOKEN_IDX = 151667 +DEFAULT_IM_START_TOKEN_IDX = 128257 +DEFAULT_IM_END_TOKEN_IDX = 128258 +UND_IMAGE_TOKEN_IDX = 151655 +# N_QUERY = 729 + +DEFAULT_IMAGE_PATCH_TOKEN = "" +IMAGE_PLACEHOLDER = "" diff --git a/blip3o/conversation.py b/blip3o/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..918a15353733feedf48292d6a4c37e5eac6d6923 --- /dev/null +++ b/blip3o/conversation.py @@ -0,0 +1,479 @@ +import dataclasses +from enum import auto, Enum +from typing import List, Tuple +import base64 +from io import BytesIO +from PIL import Image + + +class SeparatorStyle(Enum): + """Different separator style.""" + SINGLE = auto() + TWO = auto() + MPT = auto() + PLAIN = auto() + LLAMA_2 = auto() + CHATML = auto() + QWEN = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + version: str = "Unknown" + + skip_next: bool = False + + def get_prompt(self): + messages = self.messages + if len(messages) > 0 and type(messages[0][1]) is tuple: + messages = self.messages.copy() + init_role, init_msg = messages[0].copy() + init_msg = init_msg[0] + if "mmtag" in self.version: + init_msg = init_msg.replace("", "").strip() + messages[0] = (init_role, init_msg) + messages.insert(0, (self.roles[0], "")) + messages.insert(1, (self.roles[1], "Received.")) + elif not init_msg.startswith(""): + init_msg = init_msg.replace("", "").strip() + messages[0] = (init_role, "\n" + init_msg) + else: + messages[0] = (init_role, init_msg) + + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + self.sep + else: + ret += role + ":" + + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + + elif self.sep_style == SeparatorStyle.CHATML: + ret = "" if self.system == "" else self.system + self.sep + "\n" + for role, message in messages: + if message: + if type(message) is tuple: + message, images, _ = message + message = "" * len(images) + message + ret += 
role + "\n" + message + self.sep + "\n" + else: + ret += role + "\n" + return ret + + elif self.sep_style == SeparatorStyle.LLAMA_3: + if self.tokenizer is None: + raise ValueError("Llama 3 tokenizer is not available. Make sure you have the necessary permissions.") + chat_template_messages = [{"role": "system", "content": self.system}] + for role, message in messages: + if message: + if type(message) is tuple: + message, images = message + message = "" * len(images) + message + chat_template_messages.append({"role": role, "content": message}) + + # print(chat_template_messages) + return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True) + # ret = "" if self.system == "" else self.system + self.sep + "\n" + # for role, message in messages: + # if message: + # if type(message) is tuple: + # message, images = message + # message = "" * len(images) + message + # ret += role + "\n" + message + self.sep + "\n" + # else: + # ret += role + "\n" + # return ret + + elif self.sep_style == SeparatorStyle.MPT: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + + elif self.sep_style == SeparatorStyle.GEMMA: + ret = "" + for i, (role, message) in enumerate(messages): + assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..." + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + + elif self.sep_style == SeparatorStyle.LLAMA_2: + wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg + wrap_inst = lambda msg: f"[INST] {msg} [/INST]" + ret = "" + + for i, (role, message) in enumerate(messages): + if i == 0: + assert message, "first message should not be none" + assert role == self.roles[0], "first message should come from user" + if message: + if type(message) is tuple: + message, _, _ = message + if i == 0: + message = wrap_sys(self.system) + message + if i % 2 == 0: + message = wrap_inst(message) + ret += self.sep + message + else: + ret += " " + message + " " + self.sep2 + else: + ret += "" + ret = ret.lstrip(self.sep) + + elif self.sep_style == SeparatorStyle.PLAIN: + seps = [self.sep, self.sep2] + ret = self.system + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += message + seps[i % 2] + else: + ret += "" + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + return ret + + def append_message(self, role, message): + self.messages.append([role, message]) + + def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672): + if image_process_mode == "Pad": + def expand2square(pil_img, background_color=(122, 116, 104)): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + image = expand2square(image) + elif image_process_mode in ["Default", "Crop"]: + pass + elif image_process_mode == "Resize": + image = image.resize((336, 336)) + else: + raise ValueError(f"Invalid image_process_mode: {image_process_mode}") + if max(image.size) 
> max_len: + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + if return_pil: + return image + else: + buffered = BytesIO() + image.save(buffered, format=image_format) + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + return img_b64_str + + def get_images(self, return_pil=False): + images = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + msg, image, image_process_mode = msg + image = self.process_image(image, image_process_mode, return_pil=return_pil) + images.append(image) + return images + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + msg, image, image_process_mode = msg + img_b64_str = self.process_image( + image, "Default", return_pil=False, + image_format='JPEG') + img_str = f'user upload image' + msg = img_str + msg.replace('', '').strip() + ret.append([msg, None]) + else: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + version=self.version) + + def dict(self): + if len(self.get_images()) > 0: + return { + "system": self.system, + "roles": self.roles, + "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + + +conv_vicuna_v0 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ("Human", "What are the key differences between renewable and non-renewable energy sources?"), + ("Assistant", + "Renewable energy sources are those that can be replenished naturally in a relatively " + "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " + "Non-renewable energy sources, on the other hand, are finite and will eventually be " + "depleted, such as coal, oil, and natural gas. Here are some key differences between " + "renewable and non-renewable energy sources:\n" + "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " + "energy sources are finite and will eventually run out.\n" + "2. Environmental impact: Renewable energy sources have a much lower environmental impact " + "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " + "and other negative effects.\n" + "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " + "have lower operational costs than non-renewable sources.\n" + "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " + "locations than non-renewable sources.\n" + "5. 
Flexibility: Renewable energy sources are often more flexible and can be adapted to different " + "situations and needs, while non-renewable sources are more rigid and inflexible.\n" + "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " + "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_vicuna_v1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_llama_2 = Conversation( + system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", + roles=("USER", "ASSISTANT"), + version="llama_v2", + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + + +conv_blip3o_llama_2 = Conversation( + system="You are a helpful language and vision assistant. " + "You are able to understand the visual content that the user provides, " + "and assist the user with a variety of tasks using natural language.", + roles=("USER", "ASSISTANT"), + version="llama_v2", + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + +conv_mpt = Conversation( + system="""<|im_start|>system +A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +conv_blip3o_plain = Conversation( + system="", + roles=("", ""), + messages=( + ), + offset=0, + sep_style=SeparatorStyle.PLAIN, + sep="\n", +) + +conv_blip3o_v0 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ), + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_blip3o_v0_mmtag = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." + "The visual content will be provided with the following format: visual content.", + roles=("Human", "Assistant"), + messages=( + ), + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", + version="v0_mmtag", +) + +conv_blip3o_v1 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. 
" + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_blip3o_v1_mmtag = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." + "The visual content will be provided with the following format: visual content.", + roles=("USER", "ASSISTANT"), + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", + version="v1_mmtag", +) + +conv_mistral_instruct = Conversation( + system="", + roles=("USER", "ASSISTANT"), + version="llama_v2", + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + +conv_chatml_direct = Conversation( + system="""<|im_start|>system +Answer the questions.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +conv_llama3 = Conversation( + system="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""", + roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"), + version="llama3", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|eot_id|>", +) + +conv_qwen = Conversation( + system="""<|im_start|>system +You are a helpful assistant.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + version="qwen", + messages=[], + offset=0, + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", +) + + +default_conversation = conv_llama3 +conv_templates = { + "default": conv_vicuna_v0, + "v0": conv_vicuna_v0, + "v1": conv_vicuna_v1, + "vicuna_v1": conv_vicuna_v1, + "llama_2": conv_llama_2, + "mistral_instruct": conv_mistral_instruct, + "chatml_direct": conv_chatml_direct, + "mistral_direct": conv_chatml_direct, + + "plain": conv_blip3o_plain, + "v0_plain": conv_blip3o_plain, + "blip3o_v0": conv_blip3o_v0, + "v0_mmtag": conv_blip3o_v0_mmtag, + "blip3o_v1": conv_blip3o_v1, + "v1_mmtag": conv_blip3o_v1_mmtag, + "blip3o_llama_2": conv_blip3o_llama_2, + "llama3": conv_llama3, + "qwen": conv_qwen, + + "mpt": conv_mpt, +} + +if __name__ == "__main__": + print(default_conversation.get_prompt()) \ No newline at end of file diff --git a/blip3o/mm_utils.py b/blip3o/mm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d3ebdb4d06eab8ccb353c4756269a3e7c58170da --- /dev/null +++ b/blip3o/mm_utils.py @@ -0,0 +1,247 @@ +from PIL import Image +from io import BytesIO +import base64 +import torch +import math +import ast + +from transformers import StoppingCriteria +from blip3o.constants import IMAGE_TOKEN_IDX + + +def select_best_resolution(original_size, possible_resolutions): + """ + Selects the best resolution from a list of possible resolutions based on the original size. + + Args: + original_size (tuple): The original size of the image in the format (width, height). + possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + + Returns: + tuple: The best fit resolution in the format (width, height). 
+ """ + original_width, original_height = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float('inf') + + for width, height in possible_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (width, height) + + return best_fit + + +def resize_and_pad_image(image, target_resolution): + """ + Resize and pad an image to a target resolution while maintaining aspect ratio. + + Args: + image (PIL.Image.Image): The input image. + target_resolution (tuple): The target resolution (width, height) of the image. + + Returns: + PIL.Image.Image: The resized and padded image. + """ + original_width, original_height = image.size + target_width, target_height = target_resolution + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.ceil(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.ceil(original_width * scale_h), target_width) + + # Resize the image + resized_image = image.resize((new_width, new_height)) + + new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) + paste_x = (target_width - new_width) // 2 + paste_y = (target_height - new_height) // 2 + new_image.paste(resized_image, (paste_x, paste_y)) + + return new_image + + +def divide_to_patches(image, patch_size): + """ + Divides an image into patches of a specified size. + + Args: + image (PIL.Image.Image): The input image. + patch_size (int): The size of each patch. + + Returns: + list: A list of PIL.Image.Image objects representing the patches. + """ + patches = [] + width, height = image.size + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + box = (j, i, j + patch_size, i + patch_size) + patch = image.crop(box) + patches.append(patch) + + return patches + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (tuple): The size of the input image in the format (width, height). + grid_pinpoints (str): A string representation of a list of possible resolutions. + patch_size (int): The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). + """ + if type(grid_pinpoints) is list: + possible_resolutions = grid_pinpoints + else: + possible_resolutions = ast.literal_eval(grid_pinpoints) + width, height = select_best_resolution(image_size, possible_resolutions) + return width // patch_size, height // patch_size + + +def process_anyres_image(image, processor, grid_pinpoints): + """ + Process an image with variable resolutions. + + Args: + image (PIL.Image.Image): The input image to be processed. + processor: The image processor object. + grid_pinpoints (str): A string representation of a list of possible resolutions. 
+ + Returns: + torch.Tensor: A tensor containing the processed image patches. + """ + if type(grid_pinpoints) is list: + possible_resolutions = grid_pinpoints + else: + possible_resolutions = ast.literal_eval(grid_pinpoints) + best_resolution = select_best_resolution(image.size, possible_resolutions) + image_padded = resize_and_pad_image(image, best_resolution) + + patches = divide_to_patches(image_padded, processor.crop_size['height']) + + image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge'])) + + image_patches = [image_original_resize] + patches + image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] + for image_patch in image_patches] + return torch.stack(image_patches, dim=0) + + +def load_image_from_base64(image): + return Image.open(BytesIO(base64.b64decode(image))) + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def process_images(images, image_processor, model_cfg): + image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) + new_images = [] + if image_aspect_ratio == 'pad': + for image in images: + image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) + image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + new_images.append(image) + elif image_aspect_ratio == "anyres": + for image in images: + image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) + new_images.append(image) + else: + return image_processor(images, return_tensors='pt')['pixel_values'] + if all(x.shape == new_images[0].shape for x in new_images): + new_images = torch.stack(new_images, dim=0) + return new_images + + +def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_IDX, return_tensors=None): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids + + +def get_model_name_from_path(model_path): + model_path = model_path.strip("/") + model_paths = model_path.split("/") + if model_paths[-1].startswith('checkpoint-'): + return model_paths[-2] + "_" + model_paths[-1] + else: + return model_paths[-1] + +class KeywordsStoppingCriteria(StoppingCriteria): + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.keyword_ids = [] + self.max_keyword_len = 0 + for keyword in keywords: + cur_keyword_ids = tokenizer(keyword).input_ids + if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: + cur_keyword_ids = cur_keyword_ids[1:] + if len(cur_keyword_ids) > 
self.max_keyword_len: + self.max_keyword_len = len(cur_keyword_ids) + self.keyword_ids.append(torch.tensor(cur_keyword_ids)) + self.tokenizer = tokenizer + self.start_len = input_ids.shape[1] + + def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) + self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] + for keyword_id in self.keyword_ids: + truncated_output_ids = output_ids[0, -keyword_id.shape[0]:] + if torch.equal(truncated_output_ids, keyword_id): + return True + outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] + for keyword in self.keywords: + if keyword in outputs: + return True + return False + + def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + outputs = [] + for i in range(output_ids.shape[0]): + outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) + return all(outputs) diff --git a/blip3o/model/__init__.py b/blip3o/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..21a35e2269a6fed6ab8ba0a4928daa3d97dbcf47 --- /dev/null +++ b/blip3o/model/__init__.py @@ -0,0 +1,3 @@ +from .language_model.blip3o_llama import blip3oLlamaForCausalLM, blip3oConfig +from .language_model.blip3o_qwen import blip3oQwenForCausalLM, blip3oQwenConfig + diff --git a/blip3o/model/apply_delta.py b/blip3o/model/apply_delta.py new file mode 100644 index 0000000000000000000000000000000000000000..cce34f7658b9475e0b268408c5777402ec47d9b5 --- /dev/null +++ b/blip3o/model/apply_delta.py @@ -0,0 +1,48 @@ +""" +Usage: +python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta +""" +import argparse + +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM +from blip3o import blip3oLlamaForCausalLM + + +def apply_delta(base_model_path, target_model_path, delta_path): + print("Loading base model") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + + print("Loading delta") + delta = blip3oLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) + + print("Applying delta") + for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): + if name not in base.state_dict(): + assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' + continue + if param.data.shape == base.state_dict()[name].shape: + param.data += base.state_dict()[name] + else: + assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ + f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' + bparam = base.state_dict()[name] + param.data[:bparam.shape[0], :bparam.shape[1]] += bparam + + print("Saving target model") + delta.save_pretrained(target_model_path) + delta_tokenizer.save_pretrained(target_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + + args = parser.parse_args() + + apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 
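For reference, a minimal, self-contained sketch of how the anyres helpers in blip3o/mm_utils.py above fit together; the input size, grid pinpoints, and patch size here are illustrative assumptions, not values taken from this diff:

    from PIL import Image
    from blip3o.mm_utils import select_best_resolution, resize_and_pad_image, divide_to_patches

    # Toy 800x600 image and a hypothetical list of candidate resolutions.
    image = Image.new("RGB", (800, 600), (127, 127, 127))
    pinpoints = [(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]

    # Pick the candidate that preserves the most effective resolution ((672, 672) for this input),
    # letterbox the image into it, then cut it into 336x336 patches for the vision tower.
    best = select_best_resolution(image.size, pinpoints)
    padded = resize_and_pad_image(image, best)
    patches = divide_to_patches(padded, 336)
    print(best, len(patches))  # (672, 672) and 4 patches for this toy input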
diff --git a/blip3o/model/blip3o_arch.py b/blip3o/model/blip3o_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..2c2b32903ad441656a5231ed94a8bf85897f6c76 --- /dev/null +++ b/blip3o/model/blip3o_arch.py @@ -0,0 +1,415 @@ +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .multimodal_encoder.builder import build_vision_tower, build_gen_vision_tower, build_dit +from .multimodal_projector.builder import build_vision_projector, build_down_projector, build_gen_vision_projector + +from blip3o.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_IDX, DEFAULT_IM_START_TOKEN_IDX, DEFAULT_IM_END_TOKEN_IDX, UND_IMAGE_TOKEN_IDX + + + +class blip3oMetaModel: + + def __init__(self, config): + super(blip3oMetaModel, self).__init__(config) + + if hasattr(config, "mm_vision_tower"): + # self.vision_tower = build_vision_tower(config, delay_load=True) + # self.mm_projector = build_vision_projector(config) + self.down_projector = build_down_projector(config) + + if 'unpad' in getattr(config, 'mm_patch_merge_type', ''): + self.image_newline = nn.Parameter( + torch.empty(config.hidden_size, dtype=self.dtype) + ) + + + if hasattr(config, "gen_vision_tower"): + self.gen_vision_tower = build_gen_vision_tower(config, delay_load=True) + # self.gen_projector = build_gen_vision_projector(config) + self.latent_queries = nn.Parameter(torch.randn(1, config.n_query, config.hidden_size)) + print(f" latent query size {self.latent_queries.shape}") + + if 'unpad' in getattr(config, 'mm_patch_merge_type', ''): + self.image_newline = nn.Parameter( + torch.empty(config.hidden_size, dtype=self.dtype) + ) + + self.dit, self.vae, self.noise_scheduler = build_dit(config) + + + # def get_vision_tower(self): + # vision_tower = getattr(self, 'vision_tower', None) + # if type(vision_tower) is list: + # vision_tower = vision_tower[0] + # return vision_tower + + + def get_gen_vision_tower(self): + gen_vision_tower = getattr(self, 'gen_vision_tower', None) + if type(gen_vision_tower) is list: + gen_vision_tower = gen_vision_tower[0] + return gen_vision_tower + + + def initialize_vision_modules(self, model_args, fsdp=None): + gen_vision_tower = model_args.gen_vision_tower + + mm_vision_select_layer = model_args.mm_vision_select_layer + mm_vision_select_feature = model_args.mm_vision_select_feature + + pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter + pretrain_gen_mlp_adapter = model_args.pretrain_gen_mlp_adapter + + mm_patch_merge_type = model_args.mm_patch_merge_type + + self.config.gen_vision_tower = gen_vision_tower + self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "") + + + + if getattr(self, 'dit', None) is None: + print("random initiation the DiT !!!") + self.dit, self.vae, self.noise_scheduler = build_dit(model_args) + else: + print("DiT load from checkpoint!!!") + for p in self.dit.parameters(): + p.requires_grad = True + + + if self.get_gen_vision_tower() is None: + gen_vision_tower = build_gen_vision_tower(model_args) + + if fsdp is not None and len(fsdp) > 0: + self.gen_vision_tower = [gen_vision_tower] + else: + self.gen_vision_tower = gen_vision_tower + else: + if fsdp is not None and len(fsdp) > 0: + gen_vision_tower = self.gen_vision_tower[0] + else: + gen_vision_tower = self.gen_vision_tower + gen_vision_tower.load_model() + + + self.config.use_mm_proj = True + self.config.mm_projector_type = getattr(model_args, 
'mm_projector_type', 'linear') + # self.config.gen_projector_type = getattr(model_args, 'gen_projector_type', 'linear') + + + self.config.gen_hidden_size = gen_vision_tower.hidden_size + + self.config.mm_vision_select_layer = mm_vision_select_layer + self.config.mm_vision_select_feature = mm_vision_select_feature + self.config.mm_patch_merge_type = mm_patch_merge_type + self.config.n_query = model_args.n_query + self.config.gen_pooling = model_args.gen_pooling + + # if getattr(self, 'mm_projector', None) is None: + # print("random initiation the mm_project !!!") + # self.mm_projector = build_vision_projector(self.config) + + # if 'unpad' in mm_patch_merge_type: + # embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype)) + # self.image_newline = nn.Parameter( + # torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std + # ) + # else: + # # In case it is frozen by LoRA + # for p in self.mm_projector.parameters(): + # p.requires_grad = True + + + + if getattr(self, 'down_projector', None) is None: + print("random initiation the down_projector !!!") + self.down_projector = build_down_projector(self.config) + else: + # In case it is frozen by LoRA + for p in self.down_projector.parameters(): + p.requires_grad = True + + if getattr(self, 'latent_queries', None) is None: + print("random initiation the latent_queries !!!") + self.latent_queries = nn.Parameter(torch.randn(1, self.config.n_query, self.config.hidden_size)) + else: + print("latent_queries load from checkpoint!!!") + self.latent_queries.requires_grad = True + + + if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} + + # self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) + + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + + Args: + tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. + original_size (tuple): The original size of PIL image (width, height). + + Returns: + torch.Tensor: The unpadded image tensor. 
+ """ + original_width, original_height = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding:current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding:current_width - padding] + + return unpadded_tensor + + +class blip3oMetaForCausalLM(ABC): + + @abstractmethod + def get_model(self): + pass + + def get_vision_tower(self): + return self.get_model().get_vision_tower() + + def get_gen_vision_tower(self): + return self.get_model().get_gen_vision_tower() + + def encode_image(self, images): + # breakpoint() + gen_vision_tower = self.get_gen_vision_tower() + device = gen_vision_tower.device + images = images.to(device) + prompt_image_embeds = gen_vision_tower(images) + if 'early' in self.get_gen_pooling(): + prompt_image_embeds = self.pool_img(prompt_image_embeds) + num_img, _, c = prompt_image_embeds.shape + # prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c) + + # ------------- compute similarity ------- + all_dist = 0 + count = 0 + for i in range(2, prompt_image_embeds.shape[1]-1): + diff = (prompt_image_embeds[:,i,:].unsqueeze(1) - prompt_image_embeds[:,:i,:]) + dist = torch.sqrt(diff.square().sum(-1)).min().item() + all_dist+=dist + count+=1 + all_dist /= count + # self.dist = all_dist + # print(self.dist) + + return prompt_image_embeds + + def get_mm_projector(self): + return self.get_model().mm_projector + + def get_gen_projector(self): + return None + + + def get_n_query(self): + return self.get_model().config.n_query + + def get_gen_pooling(self): + return self.get_model().config.gen_pooling + + def pool_img(self, image_features): + num_img, n, c = image_features.shape + gen_pooling = self.get_gen_pooling() + # n_query = self.get_n_query() + stride = int(gen_pooling.split('_')[-1]) + sqrt_n = int(n**0.5) + image_features = image_features.permute(0, 2, 1).view(num_img, c, sqrt_n, sqrt_n) + image_features = F.avg_pool2d(image_features, kernel_size=(stride, stride), stride=stride) + # image_features = image_features.view(num_img, c, -1).permute(0,2,1).contiguous() + return image_features + + def get_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = self.get_model().noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = self.get_model().noise_scheduler.timesteps.to(device=device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + def mask_drop(self, latents, drop_prob=0.1): + if drop_prob <= 0: + return latents + mask = torch.bernoulli(torch.zeros(latents.shape[0], device=latents.device, dtype=latents.dtype) + drop_prob) + while len(mask.shape) < len(latents.shape): + mask = mask.unsqueeze(-1) + mask = 1 - mask # need to flip 0 <-> 1 + return latents * mask + + def prepare_inputs_labels_for_multimodal( + self, input_ids, position_ids, attention_mask, past_key_values, labels, + gen_images, und_images, grid_thw, i_s_pos, 
image_sizes=None + ): + pad_ids = 128256 + vision_tower = self.visual + gen_vision_tower = self.get_gen_vision_tower() + if (gen_images is None and und_images is None) or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels, None, None, None + + + + + prompt_image_embeds = gen_vision_tower(gen_images) # TODO: check dimension + + if 'early' in self.get_gen_pooling(): + prompt_image_embeds = self.pool_img(prompt_image_embeds) + target_image_embeds = torch.clone(prompt_image_embeds).detach() + latent_queries = self.get_model().latent_queries.repeat(input_ids.shape[0], 1, 1) + H = latent_queries.shape[-1] + latent_queries = latent_queries.contiguous().view(-1, H) + + + + + # if not gen_images is None: + # prompt_image_embeds = gen_vision_tower(gen_images) # TODO: check dimension + # if 'early' in self.get_gen_pooling(): + # prompt_image_embeds = self.pool_img(prompt_image_embeds) + # # num_img, _, c = prompt_image_embeds.shape # [batch, 729, 1152] + # # prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c) + # target_image_embeds = torch.clone(prompt_image_embeds).detach() + # # prompt_image_embeds = gen_projector(prompt_image_embeds) + # latent_queries = self.get_model().latent_queries.repeat(input_ids.shape[0], 1, 1) + # H = latent_queries.shape[-1] + # latent_queries = latent_queries.contiguous().view(-1, H) + # else: + # target_image_embeds = None + # num_img = und_images.shape[0] + # dummy = torch.zeros(num_img, 3, 448, 448 , dtype=und_images.dtype, device=und_images.device) # TODO + # temp = gen_vision_tower(dummy)[:,:729,:] + # num_img, _, c = temp.shape + # temp = temp.contiguous().view(-1, c) * 1e-20 + # # temp = gen_projector(temp) * 1e-9 + # latent_queries = self.get_model().latent_queries.repeat(input_ids.shape[0], 1, 1) + # H = latent_queries.shape[-1] + # latent_queries = latent_queries.contiguous().view(-1, H) + + + if not und_images is None: + und_image_embeds = vision_tower(und_images, grid_thw=grid_thw) + # _, c = und_image_embeds.shape + # batch_size = und_images.shape[0] + # und_image_embeds = und_image_embeds.view(batch_size, -1, c) + # und_image_embeds = und_image_embeds.contiguous().view(-1, c) + # und_image_embeds = mm_projector(und_image_embeds) + + # else: + # num_img = input_ids.shape[0] + # dummy = torch.zeros(num_img, 3, 384, 384 , dtype=gen_images.dtype, device=gen_images.device) # clip (3, 336, 336) + # temp = vision_tower(dummy) + # if 'early' in self.get_gen_pooling(): + # temp = temp[:,:64,:] + # num_img, _, c = temp.shape + # temp = temp.contiguous().view(-1, c) + # temp = mm_projector(temp) * 1e-20 + # latent_queries += temp + + + + + + image_idx = (input_ids == IMAGE_TOKEN_IDX) + und_image_idx = (input_ids == UND_IMAGE_TOKEN_IDX) + # img_indicator = torch.clone(image_idx) + output_indicator = labels != -100 + input_indicator = labels == -100 + # img_loss_indicator = torch.logical_and(output_indicator, image_idx) + # img_loss_indicator = torch.cat( + # [img_loss_indicator[:, 1:], img_loss_indicator[:, :1]], dim=1) + + # img_indicator = torch.cat( + # [img_indicator[:, 1:], img_indicator[:, :1]], dim=1) + + # if not target_image_embeds is None: + # target_image_embeds = target_image_embeds[-img_loss_indicator.sum():,:] + text_embeds = self.get_model().embed_tokens(input_ids) + # N_QUERY = self.get_n_query() + gen_img_idx = torch.logical_and(output_indicator, image_idx) + + # if not target_image_embeds is None: + text_embeds = text_embeds.clone() + text_embeds[gen_img_idx] = latent_queries + # 
text_embeds[gen_img_idx] = prompt_image_embeds.to(text_embeds.device)[:gen_img_idx.sum(),:] + # target_image_embeds = target_image_embeds.to(text_embeds.device)[:gen_img_idx.sum(),:] + + und_img_idx = torch.logical_and(input_indicator, und_image_idx) + + + if not und_images is None: + text_embeds[und_img_idx] = und_image_embeds.to(text_embeds.device)[:und_img_idx.sum(), :] + + labels[image_idx] = -100 + + + return None, position_ids, attention_mask, past_key_values, text_embeds, labels, target_image_embeds + + + + def initialize_vision_tokenizer(self, model_args, tokenizer): + if model_args.mm_use_im_patch_token: + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if model_args.mm_use_im_start_end: + num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = self.get_input_embeddings().weight.data + output_embeddings = self.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = True + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + + if model_args.pretrain_mm_mlp_adapter: + mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') + embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] + assert num_new_tokens == 2 + if input_embeddings.shape == embed_tokens_weight.shape: + input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] + elif embed_tokens_weight.shape[0] == num_new_tokens: + input_embeddings[-num_new_tokens:] = embed_tokens_weight + else: + raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. 
Numer of new tokens: {num_new_tokens}.") + elif model_args.mm_use_im_patch_token: + if model_args.tune_mm_mlp_adapter: + for p in self.get_input_embeddings().parameters(): + p.requires_grad = False + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False diff --git a/blip3o/model/builder.py b/blip3o/model/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..0472676c97d466807e9cc6ce9c6086c12ce9f831 --- /dev/null +++ b/blip3o/model/builder.py @@ -0,0 +1,54 @@ +import os +import warnings +import shutil + +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig +import torch +from blip3o.model import * +from blip3o.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from blip3o.train.train import smart_tokenizer_and_embedding_resize + + +def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs): + kwargs = {"device_map": device_map, **kwargs} + + if device != "cuda": + kwargs['device_map'] = {"": device} + + if load_8bit: + kwargs['load_in_8bit'] = True + elif load_4bit: + kwargs['load_in_4bit'] = True + kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4' + ) + else: + kwargs['torch_dtype'] = torch.float16 + + if use_flash_attn: + kwargs['attn_implementation'] = 'flash_attention_2' + + + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + + model = blip3oQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.float16).to('cuda:0') + + image_processor = None + + mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) + mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) + if mm_use_im_patch_token: + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + if mm_use_im_start_end: + tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + model.resize_token_embeddings(len(tokenizer)) + + if hasattr(model.config, "max_sequence_length"): + context_len = model.config.max_sequence_length + else: + context_len = 2048 + + return tokenizer, model, context_len \ No newline at end of file diff --git a/blip3o/model/consolidate.py b/blip3o/model/consolidate.py new file mode 100644 index 0000000000000000000000000000000000000000..eb2864b474756d3d16560f6f6698bea1e6f1c004 --- /dev/null +++ b/blip3o/model/consolidate.py @@ -0,0 +1,25 @@ +import argparse + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from blip3o.model import * +from blip3o.model.utils import auto_upgrade + + +def consolidate_ckpt(src_path, dst_path): + print("Loading model") + auto_upgrade(src_path) + src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) + src_model.save_pretrained(dst_path) + src_tokenizer.save_pretrained(dst_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--src", type=str, required=True) + parser.add_argument("--dst", type=str, required=True) + + args = parser.parse_args() + + consolidate_ckpt(args.src, args.dst) diff --git a/blip3o/model/language_model/blip3o_llama.py b/blip3o/model/language_model/blip3o_llama.py new file mode 
100644 index 0000000000000000000000000000000000000000..c76d60f9f8c76632d1b27a4acaed4110d02e3790 --- /dev/null +++ b/blip3o/model/language_model/blip3o_llama.py @@ -0,0 +1,413 @@ +from typing import List, Optional, Tuple, Union +from PIL import Image + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers import AutoConfig, AutoModelForCausalLM, \ + LlamaConfig, LlamaModel, LlamaForCausalLM, AutoTokenizer + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import GenerateOutput + +from ..blip3o_arch import blip3oMetaModel, blip3oMetaForCausalLM +from blip3o.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_IDX, DEFAULT_IM_START_TOKEN_IDX, DEFAULT_IM_END_TOKEN_IDX +import pdb +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import numpy_to_pil +import numpy as np +from diffusers.models import AutoencoderKL +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler + + +class blip3oConfig(LlamaConfig): + model_type = "blip3o_llama" + + +class blip3oLlamaModel(blip3oMetaModel, LlamaModel): + config_class = blip3oConfig + + def __init__(self, config: LlamaConfig): + super(blip3oLlamaModel, self).__init__(config) + + +class blip3oLlamaForCausalLM(LlamaForCausalLM, blip3oMetaForCausalLM): + config_class = blip3oConfig + + def __init__(self, config): + super(LlamaForCausalLM, self).__init__(config) + self.model = blip3oLlamaModel(config) + self.pretraining_tp = config.pretraining_tp + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.dist = None + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + ids: Optional[list] = None, + i_s_pos: Optional[list] = None, + image_type: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + gen_image: Optional[torch.FloatTensor] = None, + und_image: Optional[torch.FloatTensor] = None, + image_sizes: Optional[List[List[int]]] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None + ) -> Union[Tuple, CausalLMOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels, + latents + ) = self.prepare_inputs_labels_for_multimodal( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + gen_image, + und_image, + i_s_pos, + image_sizes + ) + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + 
output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + total_loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = torch.nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + + # compute image loss + # target_img_embeds = torch.clone(inputs_embeds.detach())[:,1:,:] # get target image emb + img_loss_funct = torch.nn.MSELoss() + # img_hidden_states = self.get_model().down_projector(hidden_states[:,-self.get_n_query():,:]) + img_hidden_states = [] + + for b in range(hidden_states.shape[0]): + img_hidden_states.append(hidden_states[b,i_s_pos[b]:i_s_pos[b]+64,:]) + img_hidden_states = torch.stack(img_hidden_states,dim=0) + img_hidden_states = self.get_model().down_projector(img_hidden_states) + # img_loss = 0.0 + if latents is None: + img_loss = img_loss_funct(img_hidden_states, torch.clone(img_hidden_states.detach())) + else: + bsz = latents.shape[0] + # device = latents.device + dtype = latents.dtype + noise = torch.randn_like(latents, device=latents.device) + u = torch.rand(size=(bsz,), device="cpu") + indices = (u * self.get_model().noise_scheduler.config.num_train_timesteps).long() + timesteps = self.get_model().noise_scheduler.timesteps[indices].to(device=latents.device) + sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=dtype) + noisy_latents = (1.0 - sigmas) * latents + sigmas * noise + noise_pred = self.get_model().dit( + x=noisy_latents, + timestep=timesteps, + z_latents=self.mask_drop(img_hidden_states), + ) + target = noise - latents + img_loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean") + print(f"img loss {img_loss}, text loss {loss}") + total_loss = img_loss + + return CausalLMOutputWithPast( + loss=total_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + images: Optional[torch.Tensor] = None, + image_sizes: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + if images is not None: + ( + inputs, + position_ids, + attention_mask, + _, + inputs_embeds, + img_indicator, + _ + ) = self.prepare_inputs_labels_for_understanding( + inputs, + position_ids, + attention_mask, + None, + None, + images, + image_sizes=image_sizes + ) + else: + inputs_embeds = self.get_model().embed_tokens(inputs) + + return super().generate( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs + ) + + @torch.no_grad() + def generate_image( + self, + text: List[str], + tokenizer: AutoTokenizer, + image: Optional[torch.Tensor] = None, + max_var: Optional[float] = None, + # placeholder: str = DEFAULT_IMG_PLACEHOLDER, + ): + scheduler = 
FlowMatchEulerDiscreteScheduler.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler") + + vision_tower = self.get_vision_tower() + mm_projector = self.get_mm_projector() + N_QUERY = self.get_n_query() + + if image is not None: + # image: [Batch, 3, 448, 448] + prompt_image_embeds = vision_tower(batch_images) + num_img, _, c = prompt_image_embeds.shape # [batch, 576, 1024] + all_image_embeds = torch.clone(prompt_image_embeds).detach() + prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c) + prompt_image_embeds = mm_projector(prompt_image_embeds) + + inputs = tokenizer(text, padding="longest", return_tensors="pt") + device = self.get_model().device + attention_mask = inputs.attention_mask.to(device) + input_ids = inputs.input_ids.to(device) # B x N + input_ids = torch.cat([input_ids, torch.tensor([[198]]).to(device)], dim=1) + + # breakpoint() + text_embeds = self.get_model().embed_tokens(input_ids) + latent_queries = self.get_model().latent_queries.repeat(text_embeds.shape[0], 1, 1) + text_embeds = torch.cat([text_embeds, latent_queries], dim=1) + attention_mask = torch.cat([attention_mask, torch.ones_like(latent_queries[:, :, 0])], dim=1) + + outputs = self.model( + inputs_embeds=text_embeds, + # img_indicator=img_indicator, + # concept_indicator=concept_indicator if self.use_concept_token else None, + attention_mask=attention_mask, + output_hidden_states=True, + return_dict=True, + ) + hidden_states = outputs.hidden_states[-1][:,-N_QUERY:,:] + img_hidden_states = self.get_model().down_projector(hidden_states) + output_img = self.sample_images(img_hidden_states, scheduler) + output_img = output_img.view(1, 1792, -1).permute(0,2,1).contiguous() + + return output_img + + def sample_images( + self, + img_hidden_states, + scheduler, + guidance_scale: float = 3.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + num_inference_steps: int = 30, + num_images_per_prompt: int = 1, + return_tensor=False, + **kwargs, + ): + + device = img_hidden_states.device + dtype = img_hidden_states.dtype + + img_hidden_states_null = torch.zeros_like(img_hidden_states, device=device, dtype=dtype) + img_hidden_states_input = torch.cat([img_hidden_states_null, img_hidden_states], 0) + + batch_size = img_hidden_states.shape[0] + latent_size = self.get_model().dit.config.input_size + latent_channels = self.get_model().dit.config.in_channels + + latents = randn_tensor( + shape=(batch_size * num_images_per_prompt, latent_channels, latent_size, latent_size), + generator=generator, + device=device, + dtype=dtype, + ) + + # set step values + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + scheduler.set_timesteps(num_inference_steps, sigmas=sigmas) + + # Repeat z_latents and conditions for each image per prompt + img_hidden_states_input = img_hidden_states_input.repeat_interleave(num_images_per_prompt, dim=0) + + for t in scheduler.timesteps: + latent_model_input = latents.repeat(2, 1, 1, 1) + if hasattr(scheduler, "scale_model_input"): + latent_model_input = scheduler.scale_model_input(latent_model_input, t) + + # predict noise model_output + noise_pred = self.get_model().dit( + x=latent_model_input, + timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latent_model_input.device, torch.long), + z_latents=img_hidden_states_input, + ) + + # perform guidance + noise_pred_uncond, noise_pred = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond) + + # compute previous 
image: x_t -> x_t-1 + latents = scheduler.step(noise_pred, t, latents).prev_sample + + # samples = self.decode_latents(latents, return_tensor=return_tensor) + return latents + + def decode_latents(self, latents, normalize=True, return_tensor=False): + if isinstance(self.get_model().vae, AutoencoderKL): + latents = latents / self.get_model().vae.config.scaling_factor + if self.get_model().vae.config.shift_factor is not None: + latents = latents + self.get_model().vae.config.shift_factor + latents = latents.to(dtype=torch.float32) + samples = self.get_model().vae.decode(latents).sample + else: + samples = self.get_model().vae.decode(latents) + if normalize: + samples = (samples / 2 + 0.5).clamp(0, 1) + else: + samples = samples.clamp(-1, 1) + if return_tensor: + return samples + samples = samples.cpu().permute(0, 2, 3, 1).float().numpy() + samples = numpy_to_pil(samples) + return samples + + def prepare_and_encode_inputs( + self, + inputs: List[str | Image.Image], + tokenizer: AutoTokenizer, + do_classifier_free_guidance: bool = False, + ): + # pdb.set_trace() + device = self.get_model().device + dtype = self.get_model().dtype + + has_image, has_text = False, False + text_prompt, image_prompt = "", [] + img_processor = self.get_vision_tower().image_processor + negative_prompt = {} + + for x in inputs: + if isinstance(x, str): + has_text = True + text_prompt += x + else: + has_image = True + text_prompt += DEFAULT_IMAGE_TOKEN + image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values']) + # pdb.set_trace() + if len(image_prompt) == 0: + image_prompt = None + else: + image_prompt = torch.cat(image_prompt) + image_prompt = image_prompt.type(dtype).to(device) + + if has_image and not has_text: + prompt = self.encode_images(image_prompt) + # pdb.set_trace() + if do_classifier_free_guidance: + key = "[NULL_IMAGE]" + if key not in negative_prompt: + negative_image = torch.zeros_like(image_prompt) + negative_prompt[key] = self.encode_images(negative_image) + prompt = torch.cat([prompt, negative_prompt[key]], dim=0) + else: + prompt = self.generate_image(text=[text_prompt], image=image_prompt, tokenizer=tokenizer) + if do_classifier_free_guidance: + key = "" + if key not in negative_prompt: + negative_prompt[key] = self.generate_image(text=[""], tokenizer=tokenizer) + prompt = torch.cat([prompt, negative_prompt[key]], dim=0) + + gen_pooling = self.get_gen_pooling() + n_query = self.get_n_query() + num_img, _, c = prompt.shape + if 'pool2d' in gen_pooling and has_text and not 'early' in gen_pooling: + stride = int(gen_pooling.split('_')[1]) + sqrt_n = int(n_query**0.5) + prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n) + prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride) + prompt = prompt.reshape(num_img, c, -1).permute(0,2,1) + return prompt + + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, + inputs_embeds=None, **kwargs): + images = kwargs.pop("images", None) + image_sizes = kwargs.pop("image_sizes", None) + inputs = super().prepare_inputs_for_generation( + input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs + ) + if images is not None: + inputs['images'] = images + if image_sizes is not None: + inputs['image_sizes'] = image_sizes + return inputs + +AutoConfig.register("blip3o_llama", blip3oConfig) +AutoModelForCausalLM.register(blip3oConfig, blip3oLlamaForCausalLM) diff --git a/blip3o/model/language_model/blip3o_qwen.py b/blip3o/model/language_model/blip3o_qwen.py new 
file mode 100644 index 0000000000000000000000000000000000000000..c66b1a09b18b7f95dc472e670316857e225b90b9 --- /dev/null +++ b/blip3o/model/language_model/blip3o_qwen.py @@ -0,0 +1,420 @@ +from typing import List, Optional, Tuple, Union, Dict +import torch +import torch.nn as nn +from PIL import Image +import torch.nn.functional as F + + +import transformers +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import GenerateOutput + +from blip3o.model.blip3o_arch import blip3oMetaModel, blip3oMetaForCausalLM + +from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel, Qwen2_5_VLForConditionalGeneration + +from blip3o.constants import UND_IMAGE_TOKEN_IDX + + + +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import numpy_to_pil +import numpy as np +from diffusers.models import AutoencoderKL +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler + + +class blip3oQwenConfig(Qwen2_5_VLConfig): + model_type = "blip3o_qwen" + + +class blip3oQwenModel(blip3oMetaModel, Qwen2_5_VLModel): + config_class = blip3oQwenConfig + + def __init__(self, config: Qwen2_5_VLConfig): + super(blip3oQwenModel, self).__init__(config) + + +class blip3oQwenForCausalLM(Qwen2_5_VLForConditionalGeneration, blip3oMetaForCausalLM): + config_class = blip3oQwenConfig + + def __init__(self, config): + Qwen2_5_VLForConditionalGeneration.__init__(self, config) + config.model_type = "blip3o_qwen" + + self.model = blip3oQwenModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + ids: Optional[list] = None, + i_s_pos: Optional[list] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + gen_image: Optional[torch.FloatTensor] = None, + und_image: Optional[torch.FloatTensor] = None, + grid_thw: Optional[torch.FloatTensor] = None, + image_sizes: Optional[List[List[int]]] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None + ) -> Union[Tuple, CausalLMOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels, + latents + ) = self.prepare_inputs_labels_for_multimodal( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + gen_image, + und_image, + grid_thw, + i_s_pos, + image_sizes + ) + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + 
output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + total_loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = torch.nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + + # compute image loss + # target_img_embeds = torch.clone(inputs_embeds.detach())[:,1:,:] # get target image emb + img_loss_funct = torch.nn.MSELoss() + # img_hidden_states = self.get_model().down_projector(hidden_states[:,-self.get_n_query():,:]) + img_hidden_states = [] + + for b in range(hidden_states.shape[0]): + img_hidden_states.append(hidden_states[b,i_s_pos[b]:i_s_pos[b]+64,:]) + img_hidden_states = torch.stack(img_hidden_states,dim=0) + img_hidden_states = self.get_model().down_projector(img_hidden_states) + # img_loss = 0.0 + if latents is None: + img_loss = img_loss_funct(img_hidden_states, torch.clone(img_hidden_states.detach())) + else: + bsz = latents.shape[0] + # device = latents.device + dtype = latents.dtype + noise = torch.randn_like(latents, device=latents.device) + u = torch.rand(size=(bsz,), device="cpu") + indices = (u * self.get_model().noise_scheduler.config.num_train_timesteps).long() + timesteps = self.get_model().noise_scheduler.timesteps[indices].to(device=latents.device) + sigmas = self.get_sigmas(timesteps, latents.device, n_dim=latents.ndim, dtype=dtype) + noisy_latents = (1.0 - sigmas) * latents + sigmas * noise + noise_pred = self.get_model().dit( + x=noisy_latents, + timestep=timesteps, + z_latents=self.mask_drop(img_hidden_states), + ) + target = noise - latents + img_loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean") + print(f"img loss {img_loss}") + total_loss = img_loss + + return CausalLMOutputWithPast( + loss=total_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + images: Optional[torch.Tensor] = None, + image_sizes: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + if images is not None: + ( + inputs, + position_ids, + attention_mask, + _, + inputs_embeds, + img_indicator, + _ + ) = self.prepare_inputs_labels_for_understanding( + inputs, + position_ids, + attention_mask, + None, + None, + images, + image_sizes=image_sizes + ) + else: + inputs_embeds = self.get_model().embed_tokens(inputs) + + return super().generate( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs + ) + + @torch.no_grad() + def generate_image( + self, + text: List[str], + tokenizer: AutoTokenizer, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + max_var: Optional[float] = None, + # placeholder: str = DEFAULT_IMG_PLACEHOLDER, + ): + scheduler = 
FlowMatchEulerDiscreteScheduler.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler") + + + N_QUERY = self.get_n_query() + inputs = tokenizer(text, padding="longest", return_tensors="pt") + device = self.get_model().device + attention_mask = inputs.attention_mask.to(device) + input_ids = inputs.input_ids.to(device) # B x N + input_ids = torch.cat([input_ids, torch.tensor([[151665]]).to(device)], dim=1) + # breakpoint() + + + text_embeds = self.get_model().embed_tokens(input_ids) + latent_queries = self.get_model().latent_queries.repeat(text_embeds.shape[0], 1, 1) + + + if pixel_values is not None: + und_image_idx = (input_ids == UND_IMAGE_TOKEN_IDX) + pixel_values = pixel_values.type(self.visual.dtype) + und_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + text_embeds[und_image_idx] = und_image_embeds.to(text_embeds.device)[:und_image_idx.sum(), :] + + + text_embeds = torch.cat([text_embeds, latent_queries], dim=1) + attention_mask = torch.cat([attention_mask, torch.ones_like(latent_queries[:, :, 0])], dim=1) + + + outputs = self.model( + inputs_embeds=text_embeds, + attention_mask=attention_mask, + output_hidden_states=True, + return_dict=True, + ) + hidden_states = outputs.hidden_states[-1][:,-N_QUERY:,:] + img_hidden_states = hidden_states + output_img = self.sample_images(img_hidden_states, scheduler) + output_img = output_img.view(1, 1792, -1).permute(0,2,1).contiguous() + + return output_img + + def sample_images( + self, + img_hidden_states, + scheduler, + guidance_scale: float = 3.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + num_inference_steps: int = 30, + num_images_per_prompt: int = 1, + return_tensor=False, + **kwargs, + ): + + device = img_hidden_states.device + dtype = img_hidden_states.dtype + + + img_hidden_states_null = torch.zeros_like(img_hidden_states, device=device, dtype=dtype) + img_hidden_states_input = torch.cat([img_hidden_states_null, img_hidden_states], 0) + + batch_size = img_hidden_states.shape[0] + latent_size = self.get_model().dit.config.input_size + latent_channels = self.get_model().dit.config.in_channels + + latents = randn_tensor( + shape=(batch_size * num_images_per_prompt, latent_channels, latent_size, latent_size), + generator=generator, + device=device, + dtype=dtype, + ) + + # set step values + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + scheduler.set_timesteps(num_inference_steps, sigmas=sigmas) + + # Repeat z_latents and conditions for each image per prompt + img_hidden_states_input = img_hidden_states_input.repeat_interleave(num_images_per_prompt, dim=0) + + for t in scheduler.timesteps: + latent_model_input = latents.repeat(2, 1, 1, 1) + if hasattr(scheduler, "scale_model_input"): + latent_model_input = scheduler.scale_model_input(latent_model_input, t) + + # predict noise model_output + noise_pred = self.get_model().dit( + x=latent_model_input, + timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]).to(latent_model_input.device, torch.long), + z_latents=img_hidden_states_input, + ) + + # perform guidance + noise_pred_uncond, noise_pred = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond) + + # compute previous image: x_t -> x_t-1 + latents = scheduler.step(noise_pred, t, latents).prev_sample + + # samples = self.decode_latents(latents, return_tensor=return_tensor) + # breakpoint() + return latents + + def decode_latents(self, latents, normalize=True, return_tensor=False): + 
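# Descriptive note on decode_latents (summarizes the body below):
#   1. undo the VAE scaling factor (and shift factor, if configured) applied at encode time,
#   2. run the VAE decoder to get pixel-space samples,
#   3. map samples from [-1, 1] to [0, 1] when `normalize=True`, otherwise clamp to [-1, 1],
#   4. return raw tensors if `return_tensor=True`, otherwise convert to PIL images.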
if isinstance(self.get_model().vae, AutoencoderKL): + latents = latents / self.get_model().vae.config.scaling_factor + if self.get_model().vae.config.shift_factor is not None: + latents = latents + self.get_model().vae.config.shift_factor + latents = latents.to(dtype=torch.float32) + samples = self.get_model().vae.decode(latents).sample + else: + samples = self.get_model().vae.decode(latents) + if normalize: + samples = (samples / 2 + 0.5).clamp(0, 1) + else: + samples = samples.clamp(-1, 1) + if return_tensor: + return samples + samples = samples.cpu().permute(0, 2, 3, 1).float().numpy() + samples = numpy_to_pil(samples) + return samples + + def prepare_and_encode_inputs( + self, + inputs: List[str | Image.Image], + tokenizer: AutoTokenizer, + do_classifier_free_guidance: bool = False, + ): + # pdb.set_trace() + device = self.get_model().device + dtype = self.get_model().dtype + + has_image, has_text = False, False + text_prompt, image_prompt = "", [] + img_processor = self.get_vision_tower().image_processor + negative_prompt = {} + + for x in inputs: + if isinstance(x, str): + has_text = True + text_prompt += x + else: + has_image = True + text_prompt += DEFAULT_IMAGE_TOKEN + image_prompt.append(img_processor.preprocess(x, return_tensors='pt')['pixel_values']) + # pdb.set_trace() + if len(image_prompt) == 0: + image_prompt = None + else: + image_prompt = torch.cat(image_prompt) + image_prompt = image_prompt.type(dtype).to(device) + + if has_image and not has_text: + prompt = self.encode_images(image_prompt) + # pdb.set_trace() + if do_classifier_free_guidance: + key = "[NULL_IMAGE]" + if key not in negative_prompt: + negative_image = torch.zeros_like(image_prompt) + negative_prompt[key] = self.encode_images(negative_image) + prompt = torch.cat([prompt, negative_prompt[key]], dim=0) + else: + prompt = self.generate_image(text=[text_prompt], image=image_prompt, tokenizer=tokenizer) + if do_classifier_free_guidance: + key = "" + if key not in negative_prompt: + negative_prompt[key] = self.generate_image(text=[""], tokenizer=tokenizer) + prompt = torch.cat([prompt, negative_prompt[key]], dim=0) + + gen_pooling = self.get_gen_pooling() + n_query = self.get_n_query() + num_img, _, c = prompt.shape + if 'pool2d' in gen_pooling and has_text and not 'early' in gen_pooling: + stride = int(gen_pooling.split('_')[1]) + sqrt_n = int(n_query**0.5) + prompt = prompt.permute(0, 2, 1).reshape(num_img, -1, sqrt_n, sqrt_n) + prompt = F.avg_pool2d(prompt, kernel_size=(stride, stride), stride=stride) + prompt = prompt.reshape(num_img, c, -1).permute(0,2,1) + return prompt + + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, + inputs_embeds=None, **kwargs): + images = kwargs.pop("images", None) + image_sizes = kwargs.pop("image_sizes", None) + inputs = super().prepare_inputs_for_generation( + input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs + ) + if images is not None: + inputs['images'] = images + if image_sizes is not None: + inputs['image_sizes'] = image_sizes + return inputs + +AutoConfig.register("blip3o_qwen", blip3oQwenConfig) +AutoModelForCausalLM.register(blip3oQwenConfig, blip3oQwenForCausalLM) diff --git a/blip3o/model/lumina_nextdit2d.py b/blip3o/model/lumina_nextdit2d.py new file mode 100644 index 0000000000000000000000000000000000000000..d6f87aa22cbbf909878ebfb80d72a69837e23127 --- /dev/null +++ b/blip3o/model/lumina_nextdit2d.py @@ -0,0 +1,365 @@ +# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention import LuminaFeedForward +from diffusers.models.attention_processor import Attention, LuminaAttnProcessor2_0 +from diffusers.models.embeddings import LuminaCombinedTimestepCaptionEmbedding, LuminaPatchEmbed, PixArtAlphaTextProjection + +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm +from diffusers.utils import is_torch_version, logging + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class LuminaNextDiTBlock(nn.Module): + """ + A LuminaNextDiTBlock for LuminaNextDiT2DModel. + + Parameters: + dim (`int`): Embedding dimension of the input features. + num_attention_heads (`int`): Number of attention heads. + num_kv_heads (`int`): + Number of attention heads in key and value features (if using GQA), or set to None for the same as query. + multiple_of (`int`): The number of multiple of ffn layer. + ffn_dim_multiplier (`float`): The multipier factor of ffn layer dimension. + norm_eps (`float`): The eps for norm layer. + qk_norm (`bool`): normalization for query and key. + cross_attention_dim (`int`): Cross attention embedding dimension of the input text prompt hidden_states. 
+ norm_elementwise_affine (`bool`, *optional*, defaults to True), + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + qk_norm: bool, + cross_attention_dim: int, + norm_elementwise_affine: bool = True, + ) -> None: + super().__init__() + self.head_dim = dim // num_attention_heads + + self.gate = nn.Parameter(torch.zeros([num_attention_heads])) + + # Self-attention + self.attn1 = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // num_attention_heads, + qk_norm="layer_norm_across_heads" if qk_norm else None, + heads=num_attention_heads, + kv_heads=num_kv_heads, + eps=1e-5, + bias=False, + out_bias=False, + processor=LuminaAttnProcessor2_0(), + ) + self.attn1.to_out = nn.Identity() + + # Cross-attention + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + dim_head=dim // num_attention_heads, + qk_norm="layer_norm_across_heads" if qk_norm else None, + heads=num_attention_heads, + kv_heads=num_kv_heads, + eps=1e-5, + bias=False, + out_bias=False, + processor=LuminaAttnProcessor2_0(), + ) + + self.feed_forward = LuminaFeedForward( + dim=dim, + inner_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ) + + self.norm1 = LuminaRMSNormZero( + embedding_dim=dim, + norm_eps=norm_eps, + norm_elementwise_affine=norm_elementwise_affine, + ) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + + self.norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + + self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + image_rotary_emb: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_mask: torch.Tensor, + temb: torch.Tensor, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Perform a forward pass through the LuminaNextDiTBlock. + + Parameters: + hidden_states (`torch.Tensor`): The input of hidden_states for LuminaNextDiTBlock. + attention_mask (`torch.Tensor): The input of hidden_states corresponse attention mask. + image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies. + encoder_hidden_states: (`torch.Tensor`): The hidden_states of text prompt are processed by Gemma encoder. + encoder_mask (`torch.Tensor`): The hidden_states of text prompt attention mask. + temb (`torch.Tensor`): Timestep embedding with text prompt embedding. + cross_attention_kwargs (`Dict[str, Any]`): kwargs for cross attention. 
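
        Returns:
            `torch.Tensor`: Hidden states of the same shape as the input `hidden_states`, obtained by adding
            the tanh-gated cross-attention output to the self-attention output (with a residual connection),
            followed by the gated feed-forward network.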
+ """ + residual = hidden_states + + # Self-attention + norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) + self_attn_output = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + attention_mask=attention_mask, + query_rotary_emb=image_rotary_emb, + key_rotary_emb=image_rotary_emb, + **cross_attention_kwargs, + ) + + # Cross-attention + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) + cross_attn_output = self.attn2( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=encoder_mask, + query_rotary_emb=image_rotary_emb, + key_rotary_emb=None, + **cross_attention_kwargs, + ) + cross_attn_output = cross_attn_output * self.gate.tanh().view(1, 1, -1, 1) + mixed_attn_output = self_attn_output + cross_attn_output + mixed_attn_output = mixed_attn_output.flatten(-2) + # linear proj + hidden_states = self.attn2.to_out[0](mixed_attn_output) + + hidden_states = residual + gate_msa.unsqueeze(1).tanh() * self.norm2(hidden_states) + + mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) + + hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) + + return hidden_states + + +class LuminaNextDiT2DModel(ModelMixin, ConfigMixin): + """ + LuminaNextDiT: Diffusion model with a Transformer backbone. + + Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. + + Parameters: + sample_size (`int`): The width of the latent images. This is fixed during training since + it is used to learn a number of position embeddings. + patch_size (`int`, *optional*, (`int`, *optional*, defaults to 2): + The size of each patch in the image. This parameter defines the resolution of patches fed into the model. + in_channels (`int`, *optional*, defaults to 4): + The number of input channels for the model. Typically, this matches the number of channels in the input + images. + hidden_size (`int`, *optional*, defaults to 4096): + The dimensionality of the hidden layers in the model. This parameter determines the width of the model's + hidden representations. + num_layers (`int`, *optional*, default to 32): + The number of layers in the model. This defines the depth of the neural network. + num_attention_heads (`int`, *optional*, defaults to 32): + The number of attention heads in each attention layer. This parameter specifies how many separate attention + mechanisms are used. + num_kv_heads (`int`, *optional*, defaults to 8): + The number of key-value heads in the attention mechanism, if different from the number of attention heads. + If None, it defaults to num_attention_heads. + multiple_of (`int`, *optional*, defaults to 256): + A factor that the hidden size should be a multiple of. This can help optimize certain hardware + configurations. + ffn_dim_multiplier (`float`, *optional*): + A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on + the model configuration. + norm_eps (`float`, *optional*, defaults to 1e-5): + A small value added to the denominator for numerical stability in normalization layers. + learn_sigma (`bool`, *optional*, defaults to True): + Whether the model should learn the sigma parameter, which might be related to uncertainty or variance in + predictions. + qk_norm (`bool`, *optional*, defaults to True): + Indicates if the queries and keys in the attention mechanism should be normalized. 
+ cross_attention_dim (`int`, *optional*, defaults to 2048): + The dimensionality of the text embeddings. This parameter defines the size of the text representations used + in the model. + scaling_factor (`float`, *optional*, defaults to 1.0): + A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the + overall scale of the model's operations. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["LuminaNextDiTBlock"] + + @register_to_config + def __init__( + self, + sample_size: int = 128, + patch_size: Optional[int] = 2, + in_channels: Optional[int] = 4, + hidden_size: Optional[int] = 2304, + num_layers: Optional[int] = 32, # 32 + num_attention_heads: Optional[int] = 32, # 32 + num_kv_heads: Optional[int] = None, + multiple_of: Optional[int] = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: Optional[float] = 1e-5, + learn_sigma: Optional[bool] = True, + qk_norm: Optional[bool] = True, + cross_attention_dim: Optional[int] = 2048, + scaling_factor: Optional[float] = 1.0, + ) -> None: + super().__init__() + self.sample_size = sample_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.scaling_factor = scaling_factor + self.gradient_checkpointing = False + + self.caption_projection = PixArtAlphaTextProjection(in_features=cross_attention_dim, hidden_size=hidden_size) + self.patch_embedder = LuminaPatchEmbed(patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, bias=True) + + self.time_caption_embed = LuminaCombinedTimestepCaptionEmbedding(hidden_size=min(hidden_size, 1024), cross_attention_dim=hidden_size) + + self.layers = nn.ModuleList( + [ + LuminaNextDiTBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + hidden_size, + ) + for _ in range(num_layers) + ] + ) + self.norm_out = LuminaLayerNormContinuous( + embedding_dim=hidden_size, + conditioning_embedding_dim=min(hidden_size, 1024), + elementwise_affine=False, + eps=1e-6, + bias=True, + out_dim=patch_size * patch_size * self.out_channels, + ) + # self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels) + + assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4" + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_mask: torch.Tensor, + image_rotary_emb: torch.Tensor, + cross_attention_kwargs: Dict[str, Any] = None, + return_dict=True, + ) -> torch.Tensor: + """ + Forward pass of LuminaNextDiT. + + Parameters: + hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W). + timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,). + encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D). + encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L). 
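            image_rotary_emb (torch.Tensor): Precomputed cosine and sine frequencies for the 2D rotary embedding.
            cross_attention_kwargs (Dict[str, Any], *optional*): Additional kwargs forwarded to the attention layers.
            return_dict (bool, *optional*, defaults to True): Whether to return a `Transformer2DModelOutput` instead of a plain tuple.

        Returns:
            `Transformer2DModelOutput` with the prediction in `sample`, or a one-element tuple when `return_dict=False`.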
+ """ + hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb) + image_rotary_emb = image_rotary_emb.to(hidden_states.device) + # breakpoint() + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask) + + encoder_mask = encoder_mask.bool() + + for layer in self.layers: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + hidden_states, + mask, + image_rotary_emb, + encoder_hidden_states, + encoder_mask, + temb, + cross_attention_kwargs, + **ckpt_kwargs, + ) + else: + hidden_states = layer( + hidden_states, + mask, + image_rotary_emb, + encoder_hidden_states, + encoder_mask, + temb=temb, + cross_attention_kwargs=cross_attention_kwargs, + ) + + hidden_states = self.norm_out(hidden_states, temb) + + # unpatchify + height_tokens = width_tokens = self.patch_size + height, width = img_size[0] + batch_size = hidden_states.size(0) + sequence_length = (height // height_tokens) * (width // width_tokens) + hidden_states = hidden_states[:, :sequence_length].view( + batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels + ) + output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/blip3o/model/make_delta.py b/blip3o/model/make_delta.py new file mode 100644 index 0000000000000000000000000000000000000000..b9c95bfd3578a92569f5512251b425b15729fcc1 --- /dev/null +++ b/blip3o/model/make_delta.py @@ -0,0 +1,48 @@ +import argparse + +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM +from blip3o.model.utils import auto_upgrade + + +def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): + print("Loading base model") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + + print("Loading target model") + auto_upgrade(target_model_path) + target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + + print("Calculating delta") + for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): + if name not in base.state_dict(): + assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' + continue + if param.data.shape == base.state_dict()[name].shape: + param.data -= base.state_dict()[name] + else: + assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' + bparam = base.state_dict()[name] + param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam + + print("Saving delta") + if hub_repo_id: + kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} + else: + kwargs = {} + target.save_pretrained(delta_path, **kwargs) + target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) + target_tokenizer.save_pretrained(delta_path, **kwargs) + + +if 
__name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + parser.add_argument("--hub-repo-id", type=str, default=None) + args = parser.parse_args() + + make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) diff --git a/blip3o/model/multimodal_encoder/builder.py b/blip3o/model/multimodal_encoder/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..f773b6f919123a0545a2a9fdd5bbc15e6a0f0a4b --- /dev/null +++ b/blip3o/model/multimodal_encoder/builder.py @@ -0,0 +1,63 @@ +import os +from .clip_encoder import CLIPVisionTower +from .imagebind import ImageBindWrapper +from .open_clip_encoder import OpenCLIPVisionTower +from .siglip_encoder import SigLipVisionTower +from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 + +from .eva_clip.eva_clip_encoder import EvaClipVisionTower +from .dev_eva_clip.eva_vit import EvaViTWrapper + +from blip3o.model.nextdit_crossattn import NextDiTCrossAttnConfig, NextDiTCrossAttn +from diffusers.models import AutoencoderKL +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler + + +def build_vision_tower(vision_tower_cfg, **kwargs): + vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) + is_absolute_path_exists = os.path.exists(vision_tower) + use_s2 = getattr(vision_tower_cfg, 's2', False) + if "siglip" in vision_tower: + return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs) + if "eva" in vision_tower: + return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) + if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: + if use_s2: + return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) + else: + return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) + + raise ValueError(f'Unknown vision tower: {vision_tower}') + + + + +def build_gen_vision_tower(vision_tower_cfg, **kwargs): + vision_tower = getattr(vision_tower_cfg, 'gen_vision_tower') + is_absolute_path_exists = os.path.exists(vision_tower) + use_s2 = getattr(vision_tower_cfg, 's2', False) + if "siglip" in vision_tower: + return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs) + if "eva" in vision_tower: + return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) + if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: + if use_s2: + return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) + else: + return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) + + raise ValueError(f'Unknown vision tower: {vision_tower}') + + + +def build_dit(vision_tower_cfg, **kwargs): + vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae") + # vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae") + dit = NextDiTCrossAttn(NextDiTCrossAttnConfig()) + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler") + # scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", subfolder="scheduler") + vae.eval() + vae.requires_grad_(False) + return dit, vae, 
noise_scheduler + + diff --git a/blip3o/model/multimodal_encoder/clip_encoder.py b/blip3o/model/multimodal_encoder/clip_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef7f9a532f9a0042bc891f32048807e4bb53cb5 --- /dev/null +++ b/blip3o/model/multimodal_encoder/clip_encoder.py @@ -0,0 +1,172 @@ +import torch +import torch.nn as nn +from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig + +try: + from s2wrapper import forward as multiscale_forward +except: + pass + + +class CLIPVisionTower(nn.Module): + def __init__(self, vision_tower, args, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_tower_name = vision_tower + self.select_layer = args.mm_vision_select_layer + self.select_feature = getattr(args, "mm_vision_select_feature", "patch") + + if not delay_load: + print(f"Loading vision tower: {vision_tower}") + self.load_model() + elif getattr(args, "unfreeze_mm_vision_tower", False): + # TODO: better detector is needed. + print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") + self.load_model() + elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: + print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") + self.load_model() + else: + self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) + + def load_model(self, device_map=None): + if self.is_loaded: + print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name)) + return + + self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) + self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) + self.vision_tower.requires_grad_(False) + + self.is_loaded = True + + def feature_select(self, image_forward_outs): + select_feature_type = self.select_feature + + if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: + select_every_k_layer = len(image_forward_outs.hidden_states) // 4 + image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1) + select_feature_type = select_feature_type.replace("slicefour_", "") + elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]: + select_layers = [-2, -5, -8, -11, 6] + image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1) + select_feature_type = select_feature_type.replace("slice_m25811_f6_", "") + else: + image_features = image_forward_outs.hidden_states[self.select_layer] + + if select_feature_type == "patch": + image_features = image_features[:, 1:] + elif select_feature_type == "cls_patch": + image_features = image_features + else: + raise ValueError(f"Unexpected select feature: {select_feature_type}") + return image_features + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) + image_feature = self.feature_select(image_forward_out).to(image.dtype) + image_features.append(image_feature) + else: + image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) + image_features = 
self.feature_select(image_forward_outs).to(images.dtype) + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + return self.vision_tower.dtype + + @property + def device(self): + return self.vision_tower.device + + @property + def config(self): + if self.is_loaded: + return self.vision_tower.config + else: + return self.cfg_only + + @property + def hidden_size(self): + _hidden_size = self.config.hidden_size + if "slicefour" in self.select_feature: + _hidden_size *= 4 + if "slice_m25811_f6" in self.select_feature: + _hidden_size *= 5 + return _hidden_size + + @property + def num_patches_per_side(self): + return self.config.image_size // self.config.patch_size + + @property + def num_patches(self): + _num_patches = (self.config.image_size // self.config.patch_size) ** 2 + if "cls_patch" in self.select_feature: + _num_patches += 1 + return _num_patches + + @property + def image_size(self): + return self.config.image_size + + +class CLIPVisionTowerS2(CLIPVisionTower): + def __init__(self, vision_tower, args, delay_load=False): + + self.s2_scales = getattr(args, "s2_scales", "336,672,1008") + self.s2_scales = list(map(int, self.s2_scales.split(","))) + self.s2_scales.sort() + self.s2_split_size = self.s2_scales[0] + self.s2_image_size = self.s2_scales[-1] + + super().__init__(vision_tower, args, delay_load) + + # change resize/crop size in preprocessing to the largest image size in s2_scale + if not delay_load or getattr(args, "unfreeze_mm_vision_tower", False): + self.image_processor.size["shortest_edge"] = self.s2_image_size + self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size + + def load_model(self, device_map=None): + if self.is_loaded: + print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name)) + return + + self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) + self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) + self.vision_tower.requires_grad_(False) + + self.image_processor.size["shortest_edge"] = self.s2_image_size + self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size + + self.is_loaded = True + + def forward_feature(self, images): + image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + return image_features + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_feature = multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True) + image_features.append(image_feature) + else: + image_features = multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True) + + return image_features + + @property + def hidden_size(self): + return self.config.hidden_size * len(self.s2_scales) diff --git a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ede6900543160f529b87001346182b526844596c --- /dev/null +++ 
b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py @@ -0,0 +1,9 @@ +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer +from .factory import list_models, add_model_config, get_model_config, load_checkpoint +from .loss import ClipLoss +from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype +from .openai import load_openai_model, list_openai_models +from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained +from .tokenizer import SimpleTokenizer, tokenize +from .transform import image_transform diff --git a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..a670bb3fab442baeb9af53b91c312e6982af57ee --- /dev/null +++ b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py @@ -0,0 +1,2 @@ +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) diff --git a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/eva_vit_model.py b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/eva_vit_model.py new file mode 100644 index 0000000000000000000000000000000000000000..23cb38c9230f92cf5e0601a95fb6a610a9e2185c --- /dev/null +++ b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/eva_vit_model.py @@ -0,0 +1,571 @@ +# -------------------------------------------------------- +# Adapted from https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- +import math +import os +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from timm.models.layers import drop_path, to_2tuple, trunc_normal_ +except: + from timm.layers import drop_path, to_2tuple, trunc_normal_ + +from .transformer import PatchDropout +from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast + +if os.getenv("ENV_TYPE") == "deepspeed": + try: + from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint + except: + from torch.utils.checkpoint import checkpoint +else: + from torch.utils.checkpoint import checkpoint + +try: + import xformers.ops as xops +except ImportError: + xops = None + # print("Please 'pip install xformers'") + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 
"p={}".format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + drop=0.0, + subln=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + + self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity() + + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commit this for the orignal BERT implement + x = self.ffn_ln(x) + + x = self.fc2(x) + x = self.drop(x) + return x + + +class SwiGLU(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.0, norm_layer=nn.LayerNorm, subln=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.w1 = nn.Linear(in_features, hidden_features) + self.w2 = nn.Linear(in_features, hidden_features) + + self.act = act_layer() + self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity() + self.w3 = nn.Linear(hidden_features, out_features) + + self.drop = nn.Dropout(drop) + + def forward(self, x): + x1 = self.w1(x) + x2 = self.w2(x) + hidden = self.act(x1) * x2 + x = self.ffn_ln(hidden) + x = self.w3(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.subln = subln + if self.subln: + self.q_proj = nn.Linear(dim, all_head_dim, bias=False) + self.k_proj = nn.Linear(dim, all_head_dim, bias=False) + self.v_proj = nn.Linear(dim, all_head_dim, bias=False) + else: + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = 
relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity() + # self.proj = nn.Linear(all_head_dim, all_head_dim) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.xattn = xattn + self.xattn_drop = attn_drop + + self.rope = rope + + def forward(self, x, rel_pos_bias=None, attn_mask=None): + B, N, C = x.shape + if self.subln: + q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias) + k = F.linear(input=x, weight=self.k_proj.weight, bias=None) + v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias) + + q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C + k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + else: + + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C + q, k, v = qkv[0], qkv[1], qkv[2] + + if self.rope: + # slightly fast impl + q_t = q[:, :, 1:, :] + ro_q_t = self.rope(q_t) + q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v) + + k_t = k[:, :, 1:, :] + ro_k_t = self.rope(k_t) + k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v) + + if self.xattn: + q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + x = xops.memory_efficient_attention( + q, + k, + v, + p=self.xattn_drop, + scale=self.scale, + ) + x = x.reshape(B, N, -1) + x = self.inner_attn_ln(x) + x = self.proj(x) + x = self.proj_drop(x) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if self.relative_position_bias_table is not None: + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0).type_as(attn) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias.type_as(attn) + + if attn_mask is not None: + attn_mask = attn_mask.bool() + attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf")) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.inner_attn_ln(x) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + init_values=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + window_size=None, + attn_head_dim=None, + xattn=False, + rope=None, + postnorm=False, + subln=False, + naiveswiglu=False, + ): + super().__init__() + self.norm1 = norm_layer(dim) + 
self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim, xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + if naiveswiglu: + self.mlp = SwiGLU( + in_features=dim, + hidden_features=mlp_hidden_dim, + subln=subln, + norm_layer=norm_layer, + ) + else: + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, subln=subln, drop=drop) + + if init_values is not None and init_values > 0: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + self.postnorm = postnorm + + def forward(self, x, rel_pos_bias=None, attn_mask=None): + if self.gamma_1 is None: + if self.postnorm: + x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))) + x = x + self.drop_path(self.norm2(self.mlp(x))) + else: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + if self.postnorm: + x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))) + x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
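        # Conv2d patchify: (B, C, H, W) -> (B, embed_dim, H/patch, W/patch),
        # then flatten and transpose into a token sequence of shape (B, num_patches, embed_dim).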
+ x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class RelativePositionBias(nn.Module): + + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + def forward(self): + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class EVAVisionTransformer(nn.Module): + """Vision Transformer with support for patch or hybrid CNN input stage""" + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + init_values=None, + patch_dropout=0.0, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + use_shared_rel_pos_bias=False, + rope=False, + use_mean_pooling=True, + init_scale=0.001, + grad_checkpointing=False, + xattn=False, + postnorm=False, + pt_hw_seq_len=16, + intp_freq=False, + naiveswiglu=False, + subln=False, + ): + super().__init__() + self.image_size = img_size + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) + else: + self.rel_pos_bias = None + + if rope: + half_head_dim = embed_dim // num_heads // 2 + hw_seq_len = img_size // patch_size + self.rope = 
VisionRotaryEmbeddingFast( + dim=half_head_dim, + pt_seq_len=pt_hw_seq_len, + ft_seq_len=hw_seq_len if intp_freq else None, + # patch_dropout=patch_dropout + ) + else: + self.rope = None + + self.naiveswiglu = naiveswiglu + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + init_values=init_values, + window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None, + xattn=xattn, + rope=self.rope, + postnorm=postnorm, + subln=subln, + naiveswiglu=naiveswiglu, + ) + for i in range(depth) + ] + ) + self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) + self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None + self.head = nn.Linear(embed_dim, num_classes, bias=qkv_bias) if num_classes > 0 else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + + trunc_normal_(self.cls_token, std=0.02) + + self.apply(self._init_weights) + self.fix_init_weight() + + if isinstance(self.head, nn.Linear): + trunc_normal_(self.head.weight, std=0.02) + self.head.weight.data.mul_(init_scale) + if self.head.bias is not None: + self.head.bias.data.mul_(init_scale) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity() + + self.grad_checkpointing = grad_checkpointing + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + if self.naiveswiglu: + rescale(layer.mlp.w3.weight.data, layer_id + 1) + else: + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def get_cast_dtype(self) -> torch.dtype: + return self.blocks[0].mlp.fc2.weight.dtype + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_num_layers(self): + return len(self.blocks) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, "partial locking not currently supported for this model" + for param in self.parameters(): + param.requires_grad = False + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token"} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=""): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x, return_all_features=False): + + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in + # if os.getenv("RoPE") == "1": + # if self.training and not isinstance(self.patch_dropout, nn.Identity): + # x, patch_indices_keep = self.patch_dropout(x) + # self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep) + # else: + # self.rope.forward = partial(self.rope.forward, patch_indices_keep=None) + # x = self.patch_dropout(x) + # else: + x = self.patch_dropout(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + if self.grad_checkpointing: + x = checkpoint(blk, x, (rel_pos_bias,)) + else: + x = blk(x, rel_pos_bias=rel_pos_bias) + + if not return_all_features: + x = self.norm(x) + if self.fc_norm is not None: + return self.fc_norm(x.mean(1)) + else: + return x[:, 0] + return x + + def forward(self, x, return_all_features=False): + if return_all_features: + return self.forward_features(x, return_all_features) + x = self.forward_features(x) + x = self.head(x) + return x diff --git a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/factory.py b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..535eeacb7d780d473fad6cf360e38254af86cb13 --- /dev/null +++ b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/factory.py @@ -0,0 +1,528 @@ +import json +import logging +import os +import pathlib +import re +from copy import deepcopy +from pathlib import Path +from typing import Optional, Tuple, Union, Dict, Any +import torch + +try: + import deepspeed +except ImportError: + deepspeed = None + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict, get_cast_dtype +from .openai import load_openai_model +from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model +from .transform import image_transform +from .tokenizer import HFTokenizer, tokenize +from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed + + +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = (".json",) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f"*{ext}")) + + for cf in config_files: + with open(cf, "r", encoding="utf8") as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))) + + +_rescan_model_configs() # initial populate of model config registry + + +def list_models(): + """enumerate available model architectures based on config files""" + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """add model config path or file and update registry""" + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() + + 
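+# Illustrative usage sketch of the config registry (not part of the original file;
+# the model name below is only an example and may not exist under model_configs/):
+#   add_model_config("/path/to/extra_model_configs")  # register additional *.json configs
+#   print(list_models())                               # names are the config file stems
+#   cfg = get_model_config("EVA02-CLIP-L-14")          # deepcopy of the config, or None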
+def get_model_config(model_name): + if model_name in _MODEL_CONFIGS: + return deepcopy(_MODEL_CONFIGS[model_name]) + else: + return None + + +def get_tokenizer(model_name): + config = get_model_config(model_name) + tokenizer = HFTokenizer(config["text_cfg"]["hf_tokenizer_name"]) if "hf_tokenizer_name" in config["text_cfg"] else tokenize + return tokenizer + + +# loading openai CLIP weights when is_openai=True for training +def load_state_dict(checkpoint_path: str, map_location: str = "cpu", model_key: str = "model|module|state_dict", is_openai: bool = False, skip_list: list = []): + if is_openai: + model = torch.jit.load(checkpoint_path, map_location="cpu").eval() + state_dict = model.state_dict() + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + else: + checkpoint = torch.load(checkpoint_path, map_location=map_location) + for mk in model_key.split("|"): + if isinstance(checkpoint, dict) and mk in checkpoint: + state_dict = checkpoint[mk] + break + else: + state_dict = checkpoint + if next(iter(state_dict.items()))[0].startswith("module"): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + for k in skip_list: + if k in list(state_dict.keys()): + logging.info(f"Removing key {k} from pretrained checkpoint") + del state_dict[k] + + if os.getenv("RoPE") == "1": + for k in list(state_dict.keys()): + if "freqs_cos" in k or "freqs_sin" in k: + del state_dict[k] + return state_dict + + +def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True): + state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False) + # detect old format and make compatible with new format + if "positional_embedding" in state_dict and not hasattr(model, "positional_embedding"): + state_dict = convert_to_custom_text_state_dict(state_dict) + if "text.logit_scale" in state_dict and hasattr(model, "logit_scale"): + state_dict["logit_scale"] = state_dict["text.logit_scale"] + del state_dict["text.logit_scale"] + + # resize_clip_pos_embed for CLIP and open CLIP + if "visual.positional_embedding" in state_dict: + resize_clip_pos_embed(state_dict, model) + # specified to eva_vit_model + elif "visual.pos_embed" in state_dict: + resize_evaclip_pos_embed(state_dict, model) + + # resize_clip_pos_embed(state_dict, model) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}") + return incompatible_keys + + +def load_clip_visual_state_dict(checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []): + state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list) + + for k in list(state_dict.keys()): + if not k.startswith("visual."): + del state_dict[k] + for k in list(state_dict.keys()): + if k.startswith("visual."): + new_k = k[7:] + state_dict[new_k] = state_dict[k] + del state_dict[k] + return state_dict + + +def load_clip_text_state_dict(checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []): + state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list) + + for k in list(state_dict.keys()): + if k.startswith("visual."): + del state_dict[k] + return state_dict + + +def get_pretrained_tag(pretrained_model): + pretrained_model = pretrained_model.lower() + if "laion" in pretrained_model or "open_clip" in pretrained_model: + return "open_clip" + 
elif "openai" in pretrained_model: + return "clip" + elif "eva" in pretrained_model and "clip" in pretrained_model: + return "eva_clip" + else: + return "other" + + +def load_zero_partitions(model, state_dict, is_deepspeed_zero3_enabled, pretrained_model_path, ignore_mismatched_sizes=False): + """ + adept from pytorch lightning and transformers + with deepspeed.zero.Init(): + model = MyModel() + state_dict = torch.load(model_path, map_location="cpu") + load_zero_partitions(model, prefix="") + """ + + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + model_state_dict = model.state_dict() + expected_keys = list(model_state_dict.keys()) + loaded_keys = list(state_dict.keys()) + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # matching the weights in the model. + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape: + mismatched_keys.append((checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)) + del state_dict[checkpoint_key] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) + if is_deepspeed_zero3_enabled: + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): + if torch.distributed.get_rank() == 0: + module._load_from_state_dict(*args) + else: + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + # Make sure we are able to load base models as well as derived models (with heads) + start_prefix = "" + model_to_load = model + load(model_to_load, prefix=start_prefix) + del state_dict + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + if len(unexpected_keys) > 0: + logging.warning( + f"Some weights of the model checkpoint at {pretrained_model_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" + " with another architecture (e.g. 
initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" + " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logging.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logging.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logging.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_path}.\nIf your task is similar to the task the model of the checkpoint" + f" was trained on, you can already use {model.__class__.__name__} for predictions without further" + " training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join([f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" for key, shape1, shape2 in mismatched_keys]) + logging.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + + +def load_pretrained_checkpoint(model, visual_checkpoint_path, text_checkpoint_path, strict=True, visual_model=None, text_model=None, model_key="model|module|state_dict", skip_list=[]): + visual_tag = get_pretrained_tag(visual_model) + text_tag = get_pretrained_tag(text_model) + + logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}") + visual_incompatible_keys, text_incompatible_keys = None, None + if visual_checkpoint_path: + if visual_tag == "eva_clip" or visual_tag == "open_clip": + visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list) + elif visual_tag == "clip": + visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list) + else: + visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list) + + # resize_clip_pos_embed for CLIP and open CLIP + if "positional_embedding" in visual_state_dict: + resize_visual_pos_embed(visual_state_dict, model) + # specified to EVA model + elif "pos_embed" in visual_state_dict: + resize_eva_pos_embed(visual_state_dict, model) + + visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict) + logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}") + logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}") + + if text_checkpoint_path: + if text_tag == "eva_clip" or text_tag == "open_clip": + text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list) + elif text_tag == "clip": + text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list) + else: + text_state_dict = 
load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list) + + text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict) + + logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}") + logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}") + + return visual_incompatible_keys, text_incompatible_keys + + +def create_model( + model_name: str, + pretrained: Optional[str] = None, + precision: str = "fp32", + device: Union[str, torch.device] = "cpu", + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_clip: bool = False, + force_patch_dropout: Optional[float] = None, + pretrained_image: str = "", + pretrained_text: str = "", + pretrained_hf: bool = True, + pretrained_visual_model: str = None, + pretrained_text_model: str = None, + cache_dir: Optional[str] = None, + skip_list: list = [], +): + model_name = model_name.replace("/", "-") # for callers using old naming with / in ViT names + if isinstance(device, str): + device = torch.device(device) + + if pretrained and pretrained.lower() == "openai": + logging.info(f"Loading pretrained {model_name} from OpenAI.") + model = load_openai_model( + model_name, + precision=precision, + device=device, + jit=jit, + cache_dir=cache_dir, + ) + else: + model_cfg = get_model_config(model_name) + if model_cfg is not None: + logging.info(f"Loaded {model_name} model config.") + else: + logging.error(f"Model config for {model_name} not found; available models {list_models()}.") + raise RuntimeError(f"Model config for {model_name} not found.") + + if "rope" in model_cfg.get("vision_cfg", {}): + if model_cfg["vision_cfg"]["rope"]: + os.environ["RoPE"] = "1" + else: + os.environ["RoPE"] = "0" + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + if force_patch_dropout is not None: + # override the default patch dropout value + model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout + + cast_dtype = get_cast_dtype(precision) + custom_clip = model_cfg.pop("custom_text", False) or force_custom_clip or ("hf_model_name" in model_cfg["text_cfg"]) + + if custom_clip: + if "hf_model_name" in model_cfg.get("text_cfg", {}): + model_cfg["text_cfg"]["hf_model_pretrained"] = pretrained_hf + model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) + + pretrained_cfg = {} + if pretrained: + checkpoint_path = "" + pretrained_cfg = get_pretrained_cfg(model_name, pretrained) + if pretrained_cfg: + checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained): + checkpoint_path = pretrained + + if checkpoint_path: + logging.info(f"Loading pretrained {model_name} weights ({pretrained}).") + load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=False) + else: + error_str = f"Pretrained weights ({pretrained}) not found for model {model_name}." f"Available pretrained tags ({list_pretrained_tags_by_model(model_name)}." 
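+ # Reaching this branch means neither the pretrained-tag registry nor a local
+ # file path resolved to a usable checkpoint, so we log and raise.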
+ logging.warning(error_str) + raise RuntimeError(error_str) + else: + visual_checkpoint_path = "" + text_checkpoint_path = "" + + if pretrained_image: + pretrained_visual_model = pretrained_visual_model.replace("/", "-") # for callers using old naming with / in ViT names + pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image) + if "timm_model_name" in model_cfg.get("vision_cfg", {}): + # pretrained weight loading for timm models set via vision_cfg + model_cfg["vision_cfg"]["timm_model_pretrained"] = True + elif pretrained_image_cfg: + visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained_image): + visual_checkpoint_path = pretrained_image + else: + logging.warning(f"Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.") + raise RuntimeError(f"Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.") + + if pretrained_text: + pretrained_text_model = pretrained_text_model.replace("/", "-") # for callers using old naming with / in ViT names + pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text) + if pretrained_image_cfg: + text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained_text): + text_checkpoint_path = pretrained_text + else: + logging.warning(f"Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.") + raise RuntimeError(f"Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.") + + if visual_checkpoint_path: + logging.info(f"Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).") + if text_checkpoint_path: + logging.info(f"Loading pretrained {model_name}.text weights ({text_checkpoint_path}).") + + if visual_checkpoint_path or text_checkpoint_path: + load_pretrained_checkpoint(model, visual_checkpoint_path, text_checkpoint_path, strict=False, visual_model=pretrained_visual_model, text_model=pretrained_text_model, model_key="model|module|state_dict", skip_list=skip_list) + + if "fp16" in precision or "bf16" in precision: + logging.info(f"convert precision to {precision}") + model = model.to(torch.bfloat16) if "bf16" in precision else model.to(torch.float16) + + # model.to(device=device) + + # set image / mean metadata from pretrained_cfg if available, or use default + model.visual.image_mean = pretrained_cfg.get("mean", None) or OPENAI_DATASET_MEAN + model.visual.image_std = pretrained_cfg.get("std", None) or OPENAI_DATASET_STD + + if jit: + model = torch.jit.script(model) + + return model + + +def create_model_and_transforms( + model_name: str, + pretrained: Optional[str] = None, + precision: str = "fp32", + device: Union[str, torch.device] = "cpu", + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_clip: bool = False, + force_patch_dropout: Optional[float] = None, + pretrained_image: str = "", + pretrained_text: str = "", + pretrained_hf: bool = True, + pretrained_visual_model: str = None, + pretrained_text_model: str = None, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, + skip_list: list = [], +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_clip=force_custom_clip, + force_patch_dropout=force_patch_dropout, + pretrained_image=pretrained_image, 
+ pretrained_text=pretrained_text, + pretrained_hf=pretrained_hf, + pretrained_visual_model=pretrained_visual_model, + pretrained_text_model=pretrained_text_model, + cache_dir=cache_dir, + skip_list=skip_list, + ) + + image_mean = image_mean or getattr(model.visual, "image_mean", None) + image_std = image_std or getattr(model.visual, "image_std", None) + preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=image_mean, std=image_std) + preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=image_mean, std=image_std) + + return model, preprocess_train, preprocess_val + + +def create_model_from_pretrained( + model_name: str, + pretrained: str, + precision: str = "fp32", + device: Union[str, torch.device] = "cpu", + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_clip: bool = False, + force_patch_dropout: Optional[float] = None, + return_transform: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, + is_frozen: bool = False, +): + if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained): + raise RuntimeError(f"{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}." f" Use open_clip.list_pretrained() to find one.") + + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_clip=force_custom_clip, + force_patch_dropout=force_patch_dropout, + cache_dir=cache_dir, + ) + + if is_frozen: + for param in model.parameters(): + param.requires_grad = False + + if not return_transform: + return model + + image_mean = image_mean or getattr(model.visual, "image_mean", None) + image_std = image_std or getattr(model.visual, "image_std", None) + preprocess = image_transform(model.visual.image_size, is_train=False, mean=image_mean, std=image_std) + + return model, preprocess diff --git a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..ddd2c672fdcc895ce2e437f6da5ee439c09840da --- /dev/null +++ b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py @@ -0,0 +1,57 @@ +# HF architecture dict: +arch_dict = { + # https://huggingface.co/docs/transformers/model_doc/roberta#roberta + "roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings", + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig + "xlm-roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings", + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 + "mt5": { + "config_names": { + # unlimited seqlen + # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 + # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 + "context_length": "", + "vocab_size": "vocab_size", + 
"width": "d_model", + "heads": "num_heads", + "layers": "num_layers", + "layer_attr": "block", + "token_embeddings_attr": "embed_tokens", + }, + "pooler": "mean_pooler", + }, + "bert": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings", + }, + "pooler": "mean_pooler", + }, +} diff --git a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_model.py b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a156624bad999775be6dc2741be648d3d2c15c67 --- /dev/null +++ b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_model.py @@ -0,0 +1,240 @@ +""" huggingface model adapter + +Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. +""" + +import re + +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch import TensorType + +try: + import transformers + from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig + from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions +except ImportError as e: + transformers = None + + class BaseModelOutput: + pass + + class PretrainedConfig: + pass + + +from .hf_configs import arch_dict + + +# utils +def _camel2snake(s): + return re.sub(r"(? TensorType: + # image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device) + # attn_mask = (x != self.config.pad_token_id).long() + # out = self.transformer( + # input_ids=x, + # attention_mask=attn_mask, + # encoder_hidden_states = image_embeds, + # encoder_attention_mask = image_atts, + # ) + # pooled_out = self.pooler(out, attn_mask) + + # return self.itm_proj(pooled_out) + + def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None): + if masked_indices is None: + masked_indices = torch.bernoulli(probability_matrix).bool() + + masked_indices[input_ids == self.tokenizer.pad_token_id] = False + masked_indices[input_ids == self.tokenizer.cls_token_id] = False + + if targets is not None: + targets[~masked_indices] = -100 # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices + input_ids[indices_replaced] = self.tokenizer.mask_token_id + + # 10% of the time, we replace masked input tokens with random word + indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced + random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device) + input_ids[indices_random] = random_words[indices_random] + # The rest of the time (10% of the time) we keep the masked input tokens unchanged + + if targets is not None: + return input_ids, targets + else: + return input_ids + + def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25): + labels = input_ids.clone() + attn_mask = (input_ids != self.config.pad_token_id).long() + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(input_ids.device) + vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"]) + 
probability_matrix = torch.full(labels.shape, mlm_probability) + input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels, probability_matrix=probability_matrix) + mlm_output = self.transformer( + input_ids, + attention_mask=attn_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + labels=labels, + ) + return mlm_output.loss + # mlm_output = self.transformer(input_ids, + # attention_mask = attn_mask, + # encoder_hidden_states = image_embeds, + # encoder_attention_mask = image_atts, + # return_dict = True, + # ).last_hidden_state + # logits = self.mlm_proj(mlm_output) + + # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size) + # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size) + # labels = labels[:, 1:].contiguous().view(-1) + + # mlm_loss = F.cross_entropy( + # logits, + # labels, + # # label_smoothing=0.1, + # ) + # return mlm_loss + + def forward(self, x: TensorType) -> TensorType: + attn_mask = (x != self.config.pad_token_id).long() + out = self.transformer(input_ids=x, attention_mask=attn_mask) + pooled_out = self.pooler(out, attn_mask) + + return self.proj(pooled_out) + + def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): + if not unlocked_layers: # full freezing + for n, p in self.transformer.named_parameters(): + p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False + return + + encoder = self.transformer.encoder if hasattr(self.transformer, "encoder") else self.transformer + layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"]) + print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model") + embeddings = getattr(self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"]) + modules = [embeddings, *layer_list][:-unlocked_layers] + # freeze layers + for module in modules: + for n, p in module.named_parameters(): + p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.gradient_checkpointing_enable() + + def get_num_layers(self): + encoder = self.transformer.encoder if hasattr(self.transformer, "encoder") else self.transformer + layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"]) + return len(layer_list) + + def init_parameters(self): + pass diff --git a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..ec4c995071fd54545d56427fa06c1789467da8e9 --- /dev/null +++ b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py @@ -0,0 +1,123 @@ +import math +import torch +import torch.nn as nn +from torch.nn import functional as F + +try: + import torch.distributed.nn + from torch import distributed as dist + + has_distributed = True +except ImportError: + has_distributed = False + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +from timm.loss import LabelSmoothingCrossEntropy + + +def gather_features(image_features, text_features, local_loss=False, gather_with_grad=False, rank=0, world_size=1, use_horovod=False): + assert has_distributed, "torch.distributed did not import correctly, please use a PyTorch version with support." 
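+ # Two gather paths: Horovod allgather or torch.distributed all_gather. With
+ # gather_with_grad the gathered tensors keep autograd history; otherwise only the
+ # local rank's slice is swapped back in so its gradients still flow when
+ # local_loss is False.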
+ if use_horovod: + assert hvd is not None, "Please install horovod" + if gather_with_grad: + all_image_features = hvd.allgather(image_features) + all_text_features = hvd.allgather(text_features) + else: + with torch.no_grad(): + all_image_features = hvd.allgather(image_features) + all_text_features = hvd.allgather(text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) + gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat(gathered_image_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + else: + # We gather tensors from all gpus + if gather_with_grad: + all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) + all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) + # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0) + # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0) + else: + gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] + gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] + dist.all_gather(gathered_image_features, image_features) + dist.all_gather(gathered_text_features, text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat(gathered_image_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + + return all_image_features, all_text_features + + +class ClipLoss(nn.Module): + + def __init__( + self, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + smoothing=0.0, + ): + super().__init__() + self.local_loss = local_loss + self.gather_with_grad = gather_with_grad + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + self.use_horovod = use_horovod + self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None + + # cache state + self.prev_num_logits = 0 + self.labels = {} + + def forward(self, image_features, text_features, logit_scale=1.0): + device = image_features.device + if self.world_size > 1: + all_image_features, all_text_features = gather_features(image_features, text_features, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) + + if self.local_loss: + logits_per_image = logit_scale * image_features @ all_text_features.T + logits_per_text = logit_scale * text_features @ all_image_features.T + else: + logits_per_image = logit_scale * all_image_features @ all_text_features.T + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + # calculated ground-truth and cache if enabled + num_logits = logits_per_image.shape[0] + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + 
num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + + if self.label_smoothing_cross_entropy: + total_loss = (self.label_smoothing_cross_entropy(logits_per_image, labels) + self.label_smoothing_cross_entropy(logits_per_text, labels)) / 2 + else: + total_loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2 + + acc = None + i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image) + t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text) + acc = {"i2t": i2t_acc, "t2i": t2i_acc} + return total_loss, acc diff --git a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/model.py b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..b2f3a2317d8dcd512a2f0019dfd7cbdcae79e3b8 --- /dev/null +++ b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/model.py @@ -0,0 +1,429 @@ +""" CLIP Model + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import os +from dataclasses import dataclass +from typing import Optional, Tuple, Union +from functools import partial + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +try: + from .hf_model import HFTextEncoder +except: + HFTextEncoder = None +from .modified_resnet import ModifiedResNet +from .timm_model import TimmModel +from .eva_vit_model import EVAVisionTransformer +from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer + +try: + from apex.normalization import FusedLayerNorm +except: + FusedLayerNorm = LayerNorm + # print("Please 'pip install apex'") + +try: + import xformers.ops as xops +except ImportError: + xops = None + # print("Please 'pip install xformers'") + + +class RMSnorm(nn.Module): + """ + adepted from transformers T5LayerNorm + """ + + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +@dataclass +class CLIPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + head_width: int = 64 + mlp_ratio: float = 4.0 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + ls_init_value: Optional[float] = None # layer scale initial value + patch_dropout: float = 0.0 # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) + drop_path_rate: Optional[float] = None # drop path rate + timm_model_name: str = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = "linear" # linear projection for timm model output ('linear', 'mlp', '') + timm_proj_bias: bool = False # enable bias final projection + eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size + qkv_bias: bool = True + fusedLN: bool = False + xattn: bool = False + postnorm: bool = False + rope: bool = False + pt_hw_seq_len: int = 16 # 224/14 + intp_freq: bool = False + naiveswiglu: bool = False + subln: bool = False + use_rms_norm: bool = False + + +@dataclass +class CLIPTextCfg: + context_length: int = 77 + vocab_size: int = 49408 + width: int = 512 + heads: int = 8 + layers: int = 12 + ls_init_value: Optional[float] = None # layer scale initial value + hf_model_name: str = None + hf_tokenizer_name: str = None + hf_model_pretrained: bool = True + proj: str = "mlp" + pooler_type: str = "mean_pooler" + masked_language_modeling: bool = False + fusedLN: bool = False + xattn: bool = False + attn_mask: bool = True + + +def get_cast_dtype(precision: str): + cast_dtype = None + if precision == "bf16": + cast_dtype = torch.bfloat16 + elif precision == "fp16": + cast_dtype = torch.float16 + return cast_dtype + + +def _build_vision_tower(embed_dim: int, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
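+ # Vision tower dispatch below: eva_model_name -> EVAVisionTransformer,
+ # timm_model_name -> TimmModel, a tuple/list of layers -> ModifiedResNet,
+ # otherwise the default VisionTransformer.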
+ act_layer = QuickGELU if quick_gelu else nn.GELU + + if vision_cfg.eva_model_name: + vision_heads = vision_cfg.width // vision_cfg.head_width + + norm_layer = RMSnorm if vision_cfg.use_rms_norm else LayerNorm + + visual = EVAVisionTransformer( + img_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + num_classes=embed_dim, + use_mean_pooling=vision_cfg.global_average_pool, # False + init_values=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + embed_dim=vision_cfg.width, + depth=vision_cfg.layers, + num_heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + qkv_bias=vision_cfg.qkv_bias, + drop_path_rate=vision_cfg.drop_path_rate, + norm_layer=partial(norm_layer, eps=1e-6), + xattn=vision_cfg.xattn, + rope=vision_cfg.rope, + postnorm=vision_cfg.postnorm, + pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14 + intp_freq=vision_cfg.intp_freq, + naiveswiglu=vision_cfg.naiveswiglu, + subln=vision_cfg.subln, + ) + elif vision_cfg.timm_model_name: + visual = TimmModel( + vision_cfg.timm_model_name, pretrained=vision_cfg.timm_model_pretrained, pool=vision_cfg.timm_pool, proj=vision_cfg.timm_proj, proj_bias=vision_cfg.timm_proj_bias, embed_dim=embed_dim, image_size=vision_cfg.image_size + ) + act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models + elif isinstance(vision_cfg.layers, (tuple, list)): + vision_heads = vision_cfg.width * 32 // vision_cfg.head_width + visual = ModifiedResNet(layers=vision_cfg.layers, output_dim=embed_dim, heads=vision_heads, image_size=vision_cfg.image_size, width=vision_cfg.width) + else: + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + visual = VisionTransformer( + image_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + width=vision_cfg.width, + layers=vision_cfg.layers, + heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + ls_init_value=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + global_average_pool=vision_cfg.global_average_pool, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return visual + + +def _build_text_tower( + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(text_cfg, dict): + text_cfg = CLIPTextCfg(**text_cfg) + + if text_cfg.hf_model_name: + text = HFTextEncoder(text_cfg.hf_model_name, output_dim=embed_dim, tokenizer_name=text_cfg.hf_tokenizer_name, proj=text_cfg.proj, pooler_type=text_cfg.pooler_type, masked_language_modeling=text_cfg.masked_language_modeling) + else: + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = LayerNorm + + text = TextTransformer( + context_length=text_cfg.context_length, + vocab_size=text_cfg.vocab_size, + width=text_cfg.width, + heads=text_cfg.heads, + layers=text_cfg.layers, + ls_init_value=text_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=FusedLayerNorm if text_cfg.fusedLN else norm_layer, + xattn=text_cfg.xattn, + attn_mask=text_cfg.attn_mask, + ) + return text + + +class CLIP(nn.Module): + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.transformer 
= text.transformer + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.register_buffer("attn_mask", text.attn_mask, persistent=False) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.transformer.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + return {"logit_scale"} + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return F.normalize(x, dim=-1) if normalize else x + + def forward(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + return image_features, text_features, self.logit_scale.exp() + + +class CustomCLIP(nn.Module): + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + itm_task: bool = False, + ): + super().__init__() + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): + self.text.lock(unlocked_layers, freeze_layer_norm) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + + @torch.jit.ignore + def no_weight_decay(self): + return {"logit_scale"} + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + features = self.text(text) + return F.normalize(features, dim=-1) if normalize else features + + def forward(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + return image_features, text_features, self.logit_scale.exp() + + +def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): 
+ """Convert applicable model parameters to low-precision (bf16 or fp16)""" + + def _convert_weights(l): + + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(dtype) + if l.bias is not None: + l.bias.data = l.bias.data.to(dtype) + + if isinstance(l, (nn.MultiheadAttention, Attention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr, None) + if tensor is not None: + tensor.data = tensor.data.to(dtype) + + if isinstance(l, nn.Parameter): + l.data = l.data.to(dtype) + + for name in ["text_projection", "proj"]: + if hasattr(l, name) and isinstance(l, nn.Parameter): + attr = getattr(l, name, None) + if attr is not None: + attr.data = attr.data.to(dtype) + + model.apply(_convert_weights) + + +convert_weights_to_fp16 = convert_weights_to_lp # backwards compat + + +# used to maintain checkpoint compatibility +def convert_to_custom_text_state_dict(state_dict: dict): + if "text_projection" in state_dict: + # old format state_dict, move text tower -> .text + new_state_dict = {} + for k, v in state_dict.items(): + if any(k.startswith(p) for p in ("text_projection", "positional_embedding", "token_embedding", "transformer", "ln_final", "logit_scale")): + k = "text." + k + new_state_dict[k] = v + return new_state_dict + return state_dict + + +def build_model_from_openai_state_dict( + state_dict: dict, + quick_gelu=True, + cast_dtype=torch.float16, +): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_size = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width**2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_size = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + vision_cfg = CLIPVisionCfg( + layers=vision_layers, + width=vision_width, + patch_size=vision_patch_size, + image_size=image_size, + ) + text_cfg = CLIPTextCfg(context_length=context_length, vocab_size=vocab_size, width=transformer_width, heads=transformer_heads, layers=transformer_layers) + model = CLIP( + embed_dim, + vision_cfg=vision_cfg, + text_cfg=text_cfg, + quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU + cast_dtype=cast_dtype, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + + convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 + model.load_state_dict(state_dict) + return model.eval() + + +def trace_model(model, batch_size=256, 
device=torch.device("cpu")): + model.eval() + image_size = model.visual.image_size + example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) + example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) + model = torch.jit.trace_module(model, inputs=dict(forward=(example_images, example_text), encode_text=(example_text,), encode_image=(example_images,))) + model.visual.image_size = image_size + return model diff --git a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9f29f84540f92ab7a216c02e77eeca5a6e6653a3 --- /dev/null +++ b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py @@ -0,0 +1,179 @@ +from collections import OrderedDict + +import torch +from torch import nn +from torch.nn import functional as F + +from .utils import freeze_batch_norm_2d + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.act2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.act3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([("-1", nn.AvgPool2d(stride)), ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), ("1", nn.BatchNorm2d(planes * self.expansion))])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.act1(self.bn1(self.conv1(x))) + out = self.act2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.act3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + 
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0.0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.act2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.act3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, "partial locking not currently supported for this model" + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + # FIXME support for non-transformer + pass + + def stem(self, x): + x = self.act1(self.bn1(self.conv1(x))) + x = self.act2(self.bn2(self.conv2(x))) + x = self.act3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x diff --git 
a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..9fbf6fc7ecfe2ab3580b9d8af4793e09bc12101f --- /dev/null +++ b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py @@ -0,0 +1,144 @@ +""" OpenAI pretrained model functions + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import os +import warnings +from typing import List, Optional, Union + +import torch + +from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype +from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url + +__all__ = ["list_openai_models", "load_openai_model"] + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_models_by_tag("openai") + + +def load_openai_model( + name: str, + precision: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, + jit: bool = True, + cache_dir: Optional[str] = None, +): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + precision: str + Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. + device : Union[str, torch.device] + The device to put the loaded model + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. + cache_dir : Optional[str] + The directory to cache the downloaded model weights + + Returns + ------- + model : torch.nn.Module + The CLIP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + if precision is None: + precision = "fp32" if device == "cpu" else "fp16" + + if get_pretrained_url(name, "openai"): + model_path = download_pretrained_from_url(get_pretrained_url(name, "openai"), cache_dir=cache_dir) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + # Build a non-jit model from the OpenAI jitted model state dict + cast_dtype = get_cast_dtype(precision) + try: + model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) + + # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use + model = model.to(device) + if precision.startswith("amp") or precision == "fp32": + model.float() + elif precision == "bf16": + convert_weights_to_lp(model, dtype=torch.bfloat16) + + return model + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 (typically for CPU) + if precision == "fp32": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + model.float() + + # ensure image_size attr available at consistent location for both jit and non-jit + model.visual.image_size = model.input_resolution.item() + return model diff --git a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/pretrained.py b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..a603b6047bf86decdb9f5ae247d5cea9a555cf50 --- /dev/null +++ b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/pretrained.py @@ -0,0 +1,314 @@ +import hashlib +import os +import urllib +import warnings +from typing import Dict, Union + +from tqdm import tqdm + +try: + from huggingface_hub import hf_hub_download + + _has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False + + +def _pcfg(url="", hf_hub="", filename="", mean=None, std=None): + return dict( + url=url, + hf_hub=hf_hub, + mean=mean, + std=std, + ) + + +_VITB32 = dict( + openai=_pcfg("https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + 
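+     # The long hex segment in the "openaipublic" URL above is the checkpoint's SHA256; download_pretrained_from_url()
+     # later in this file extracts it (the second-to-last path segment) and verifies the downloaded file against it.
+     # The mlfoundations release assets instead carry a short hash suffix in the filename (e.g. "d867053b" below).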
laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), + laion2b_e16=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), + laion2b_s34b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-B-32-laion2B-s34B-b79K/"), +) + +_VITB32_quickgelu = dict( + openai=_pcfg("https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), +) + +_VITB16 = dict( + openai=_pcfg("https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), + laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), + laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), + laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-B-16-laion2B-s34B-b88K/"), +) + +_EVAB16 = dict( + eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt"), + eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt"), + eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt"), + eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt"), +) + +_VITB16_PLUS_240 = dict( + laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), + laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), +) + +_VITL14 = dict( + openai=_pcfg("https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), + laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), + laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), + laion2b_s32b_b82k=_pcfg(hf_hub="laion/CLIP-ViT-L-14-laion2B-s32B-b82K/", mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), +) + +_EVAL14 = dict( + eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_L_psz14.pt"), + eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_L_psz14.pt"), + eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt"), + eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt"), +) + +_VITL14_336 = dict( + openai=_pcfg("https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), +) + +_EVAL14_336 = dict( + eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt"), + eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt"), + eva_clip_224to336=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt"), + eva02_clip_224to336=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt"), +) + +_VITH14 = dict( + laion2b_s32b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-H-14-laion2B-s32B-b79K/"), +) + +_VITg14 
= dict( + laion2b_s12b_b42k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s12B-b42K/"), + laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s34B-b88K/"), +) + +_EVAg14 = dict( + eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/"), + eva01=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_g_psz14.pt"), + eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt"), + eva01_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt"), +) + +_EVAg14_PLUS = dict( + eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/"), + eva01=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_g_psz14.pt"), + eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt"), + eva01_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt"), +) + +_VITbigG14 = dict( + laion2b_s39b_b160k=_pcfg(hf_hub="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/"), +) + +_EVAbigE14 = dict( + eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"), + eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"), + eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt"), + eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt"), +) + +_EVAbigE14_PLUS = dict( + eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"), + eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"), + eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt"), + eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt"), +) + +_EVA_8B = dict( + eva=_pcfg(hf_hub="BAAI/EVA-CLIP-8B/EVA_8B_psz14.bin"), + eva_clip=_pcfg(hf_hub="BAAI/EVA-CLIP-8B/EVA_CLIP_8B_psz14_s9B.pt"), +) + +_EVA_8B_PLUS = dict( + eva_clip=_pcfg(hf_hub="BAAI/EVA-CLIP-8B-448/EVA_CLIP_8B_psz14_plus_s0.6B.pt"), +) + + +_PRETRAINED = { + # "ViT-B-32": _VITB32, + "OpenaiCLIP-B-32": _VITB32, + "OpenCLIP-B-32": _VITB32, + # "ViT-B-32-quickgelu": _VITB32_quickgelu, + "OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu, + "OpenCLIP-B-32-quickgelu": _VITB32_quickgelu, + # "ViT-B-16": _VITB16, + "OpenaiCLIP-B-16": _VITB16, + "OpenCLIP-B-16": _VITB16, + "EVA02-B-16": _EVAB16, + "EVA02-CLIP-B-16": _EVAB16, + # "ViT-B-16-plus-240": _VITB16_PLUS_240, + "OpenCLIP-B-16-plus-240": _VITB16_PLUS_240, + # "ViT-L-14": _VITL14, + "OpenaiCLIP-L-14": _VITL14, + "OpenCLIP-L-14": _VITL14, + "EVA02-L-14": _EVAL14, + "EVA02-CLIP-L-14": _EVAL14, + # "ViT-L-14-336": _VITL14_336, + "OpenaiCLIP-L-14-336": _VITL14_336, + "EVA02-CLIP-L-14-336": _EVAL14_336, + # "ViT-H-14": _VITH14, + # "ViT-g-14": _VITg14, + "OpenCLIP-H-14": _VITH14, + "OpenCLIP-g-14": _VITg14, + "EVA01-CLIP-g-14": _EVAg14, + "EVA01-CLIP-g-14-plus": _EVAg14_PLUS, + # "ViT-bigG-14": _VITbigG14, + "OpenCLIP-bigG-14": _VITbigG14, + "EVA02-CLIP-bigE-14": _EVAbigE14, + "EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS, + "EVA-CLIP-8B": _EVA_8B, + "EVA-CLIP-8B-448": _EVA_8B_PLUS, + "EVA-CLIP-8B-plus": _EVA_8B_PLUS, +} + + +def _clean_tag(tag: str): + # normalize pretrained tags + return tag.lower().replace("-", "_") + + +def list_pretrained(as_str: bool = False): + """returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [":".join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] + + +def list_pretrained_models_by_tag(tag: str): + """return all models having the specified pretrain tag""" + models = [] + tag = _clean_tag(tag) + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_tags_by_model(model: str): + """return all pretrain tags for the specified model 
architecture""" + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def is_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return False + return _clean_tag(tag) in _PRETRAINED[model] + + +def get_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return {} + model_pretrained = _PRETRAINED[model] + return model_pretrained.get(_clean_tag(tag), {}) + + +def get_pretrained_url(model: str, tag: str): + cfg = get_pretrained_cfg(model, _clean_tag(tag)) + return cfg.get("url", "") + + +def download_pretrained_from_url( + url: str, + cache_dir: Union[str, None] = None, +): + if not cache_dir: + cache_dir = os.path.expanduser("~/.cache/clip") + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.basename(url) + + if "openaipublic" in url: + expected_sha256 = url.split("/")[-2] + elif "mlfoundations" in url: + expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] + else: + expected_sha256 = "" + + download_target = os.path.join(cache_dir, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit="iB", unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError("Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.") + return _has_hf_hub + + +def download_pretrained_from_hf( + model_id: str, + filename: str = "open_clip_pytorch_model.bin", + revision=None, + cache_dir: Union[str, None] = None, +): + has_hf_hub(True) + cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) + return cached_file + + +def download_pretrained( + cfg: Dict, + force_hf_hub: bool = False, + cache_dir: Union[str, None] = None, +): + target = "" + if not cfg: + return target + + download_url = cfg.get("url", "") + download_hf_hub = cfg.get("hf_hub", "") + if download_hf_hub and force_hf_hub: + # use HF hub even if url exists + download_url = "" + + if download_url: + target = download_pretrained_from_url(download_url, cache_dir=cache_dir) + elif download_hf_hub: + has_hf_hub(True) + # we assume the hf_hub entries in pretrained config combine model_id + filename in + # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and + # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. 
+ model_id, filename = os.path.split(download_hf_hub) + if filename: + target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) + else: + target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + + return target diff --git a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..5fb3cce54e5ff26e53271834e65997714db3a1ad --- /dev/null +++ b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py @@ -0,0 +1,131 @@ +from math import pi +import torch +from torch import nn +from einops import rearrange, repeat +import logging + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), "invalid dimensions for broadcastable concatentation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +class VisionRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs=None, + freqs_for="lang", + theta=10000, + max_freq=10, + num_freqs=1, + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs_h = torch.einsum("..., f -> ... f", t, freqs) + freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) + + freqs_w = torch.einsum("..., f -> ... f", t, freqs) + freqs_w = repeat(freqs_w, "... n -> ... 
(n r)", r=2) + + freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1) + + self.register_buffer("freqs_cos", freqs.cos()) + self.register_buffer("freqs_sin", freqs.sin()) + + logging.info(f"Shape of rope freq: {self.freqs_cos.shape}") + + def forward(self, t, start_index=0): + rot_dim = self.freqs_cos.shape[-1] + end_index = start_index + rot_dim + assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" + t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] + t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) + + return torch.cat((t_left, t, t_right), dim=-1) + + +class VisionRotaryEmbeddingFast(nn.Module): + def __init__(self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, patch_dropout=0.0): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum("..., f -> ... f", t, freqs) + freqs = repeat(freqs, "... n -> ... (n r)", r=2) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) + + freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) + freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) + + self.patch_dropout = patch_dropout + + self.register_buffer("freqs_cos", freqs_cos) + self.register_buffer("freqs_sin", freqs_sin) + + logging.info(f"Shape of rope freq: {self.freqs_cos.shape}") + + def forward(self, t, patch_indices_keep=None): + if patch_indices_keep is not None: + batch = t.size()[0] + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1]) + freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1]) + + freqs_cos = freqs_cos[batch_indices, patch_indices_keep] + freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j") + freqs_sin = freqs_sin[batch_indices, patch_indices_keep] + freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j") + + return t * freqs_cos + rotate_half(t) * freqs_sin + + return t * self.freqs_cos + rotate_half(t) * self.freqs_sin diff --git a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..65de78df2892a4cbe01b9eef68f07f5f79b8b367 --- /dev/null +++ b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py @@ -0,0 +1,114 @@ +""" timm model adapter + +Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
+""" + +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn + +try: + import timm + from timm.models.layers import Mlp, to_2tuple + + try: + # old timm imports < 0.8.1 + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d + except ImportError: + # new timm imports >= 0.8.1 + from timm.layers import RotAttentionPool2d + from timm.layers import AttentionPool2d as AbsAttentionPool2d +except ImportError: + timm = None + +from .utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """timm model adapter + # FIXME this adapter is a work in progress, may change in ways that break weight compat + """ + + def __init__(self, model_name, embed_dim, image_size=224, pool="avg", proj="linear", proj_bias=False, drop=0.0, pretrained=False): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + + self.image_size = to_2tuple(image_size) + self.trunk = timm.create_model(model_name, pretrained=pretrained) + feat_size = self.trunk.default_cfg.get("pool_size", None) + feature_ndim = 1 if not feat_size else 2 + if pool in ("abs_attn", "rot_attn"): + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool="") + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + if pool == "abs_attn": + head_layers["pool"] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) + prev_chs = embed_dim + elif pool == "rot_attn": + head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + else: + assert proj, "projection layer needed if non-attention pooling is used." 
+ + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == "linear": + head_layers["drop"] = nn.Dropout(drop) + head_layers["proj"] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) + elif proj == "mlp": + head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias)) + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError("Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`") + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + try: + self.trunk.set_grad_checkpointing(enable) + except Exception as e: + logging.warning("grad checkpointing not supported for this timm image tower, continuing without...") + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x diff --git a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/tokenizer.py b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5f753e69bc8e24b607b0fa1378ebe236b3d47c27 --- /dev/null +++ b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/tokenizer.py @@ -0,0 +1,205 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + +# https://stackoverflow.com/q/62691279 +import os + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. 
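+     For example, printable ASCII bytes keep their own characters ("a" stays "a"), while bytes the BPE would
+     choke on are shifted to unused code points: the space byte 0x20 becomes "Ġ" (U+0120) and the newline
+     byte 0x0A becomes "Ċ" (U+010A).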
+ """ + bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + if not special_tokens: + special_tokens = ["", ""] + else: + special_tokens = ["", ""] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t: t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace").replace("", " ") + return text + + +_tokenizer = SimpleTokenizer() + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts 
: Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + eot_token = _tokenizer.encoder[""] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = eot_token + result[i, : len(tokens)] = torch.tensor(tokens) + + return result + + +class HFTokenizer: + "HuggingFace tokenizer wrapper" + + def __init__(self, tokenizer_name: str): + from transformers import AutoTokenizer + + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + texts = [whitespace_clean(basic_clean(text)) for text in texts] + input_ids = self.tokenizer(texts, return_tensors="pt", max_length=context_length, padding="max_length", truncation=True).input_ids + return input_ids diff --git a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..8cad45a167ab85eba8f84eed7feaa132dedc48d4 --- /dev/null +++ b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py @@ -0,0 +1,104 @@ +from typing import Optional, Sequence, Tuple + +import torch +import torch.nn as nn +import torchvision.transforms.functional as F + +from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, CenterCrop + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD + + +class ResizeMaxSize(nn.Module): + + def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0): + super().__init__() + if not isinstance(max_size, int): + raise TypeError(f"Size should be int. 
Got {type(max_size)}") + self.max_size = max_size + self.interpolation = interpolation + self.fn = min if fn == "min" else min + self.fill = fill + + def forward(self, img): + if isinstance(img, torch.Tensor): + height, width = img.shape[:2] + else: + width, height = img.size + scale = self.max_size / float(max(height, width)) + if scale != 1.0: + new_size = tuple(round(dim * scale) for dim in (height, width)) + img = F.resize(img, new_size, self.interpolation) + pad_h = self.max_size - new_size[0] + pad_w = self.max_size - new_size[1] + img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill) + return img + + +def _convert_to_rgb(image): + return image.convert("RGB") + + +# class CatGen(nn.Module): +# def __init__(self, num=4): +# self.num = num +# def mixgen_batch(image, text): +# batch_size = image.shape[0] +# index = np.random.permutation(batch_size) + +# cat_images = [] +# for i in range(batch_size): +# # image mixup +# image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:] +# # text concat +# text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0] +# text = torch.stack(text) +# return image, text + + +def image_transform( + image_size: int, + is_train: bool, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + resize_longest_max: bool = False, + fill_color: int = 0, +): + mean = mean or OPENAI_DATASET_MEAN + if not isinstance(mean, (list, tuple)): + mean = (mean,) * 3 + + std = std or OPENAI_DATASET_STD + if not isinstance(std, (list, tuple)): + std = (std,) * 3 + + if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: + # for square size, pass size as int so that Resize() uses aspect preserving shortest edge + image_size = image_size[0] + + normalize = Normalize(mean=mean, std=std) + if is_train: + return Compose( + [ + RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), + _convert_to_rgb, + ToTensor(), + normalize, + ] + ) + else: + if resize_longest_max: + transforms = [ResizeMaxSize(image_size, fill=fill_color)] + else: + transforms = [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + ] + transforms.extend( + [ + _convert_to_rgb, + ToTensor(), + normalize, + ] + ) + return Compose(transforms) diff --git a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/transformer.py b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..bd5ce4b6e5a3e8eadc0bacef4b270f829cdab762 --- /dev/null +++ b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/transformer.py @@ -0,0 +1,683 @@ +import os +import logging +from collections import OrderedDict +import math +from typing import Callable, Optional, Sequence +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +try: + from timm.models.layers import trunc_normal_ +except: + from timm.layers import trunc_normal_ + +from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast +from .utils import to_2tuple + +if os.getenv("ENV_TYPE") == "deepspeed": + try: + import deepspeed + from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint + except: + print("Please 'pip install deepspeed'") + deepspeed = None + from torch.utils.checkpoint import checkpoint +else: + from torch.utils.checkpoint import checkpoint + +try: + import xformers.ops as xops +except ImportError: + xops = None + # 
print("Please 'pip install xformers'") + + +class LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: torch.Tensor): + output = F.layer_norm( + x.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(x) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1.0 + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}") + + def forward(self, x): + if not self.training or self.prob == 0.0: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + if self.training and os.getenv("RoPE") == "1": + return x, patch_indices_keep + + return x + + +def _in_projection_packed( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: Optional[torch.Tensor] = None, +): + """ + https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726 + """ + E = q.size(-1) + if k is v: + if q is k: + # self-attention + return F.linear(q, w, b).chunk(3, dim=-1) + else: + # encoder-decoder attention + w_q, w_kv = w.split([E, E * 2]) + if b is None: + b_q = b_kv = None + else: + b_q, b_kv = b.split([E, E * 2]) + return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1) + else: + w_q, w_k, w_v = w.chunk(3) + if b is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = b.chunk(3) + return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v) + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=True, scaled_cosine=False, scale_heads=False, logit_scale_max=math.log(1.0 / 0.01), attn_drop=0.0, proj_drop=0.0, xattn=False, rope=False): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + 
self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + self.xattn = xattn + self.xattn_drop = attn_drop + self.rope = rope + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + if self.xattn: + q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1) + + x = xops.memory_efficient_attention( + q, + k, + v, + p=self.xattn_drop, + scale=self.scale if self.logit_scale is None else None, + attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None, + ) + else: + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class CustomAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=True, scaled_cosine=True, scale_heads=False, logit_scale_max=math.log(1.0 / 0.01), attn_drop=0.0, proj_drop=0.0, xattn=False): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + 
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + self.xattn = xattn + self.xattn_drop = attn_drop + + def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias) + N_q, B_q, C_q = q.shape + N_k, B_k, C_k = k.shape + N_v, B_v, C_v = v.shape + if self.xattn: + # B, N, C -> B, N, num_heads, C + q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1) + k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1) + v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1) + + x = xops.memory_efficient_attention(q, k, v, p=self.xattn_drop, scale=self.scale if self.logit_scale is None else None, attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None) + else: + # B*H, L, C + q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + # B*H, N_q, N_k + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale + attn = attn.view(-1, N_q, N_k) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + + if self.head_scale is not None: + x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale + x = x.view(-1, N_q, C_q) + x = x.transpose(0, 1).reshape(N_q, B_q, C_q) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class CustomResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = False, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + cross_attn: bool = False, + xattn: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1 + self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1 + self.attn = CustomAttention(d_model, n_head, qkv_bias=True, attn_drop=0.0, proj_drop=0.0, scaled_cosine=scale_cosine_attn, scale_heads=scale_heads, xattn=xattn) + + self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([("c_fc", nn.Linear(d_model, mlp_width)), ("ln", norm_layer(mlp_width) if scale_fc else nn.Identity()), ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, d_model))])) + + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = 
None): + q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask))) + q = q + self.ls_2(self.mlp(self.ln_2(q))) + return q + + +class CustomTransformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = True, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + cross_attn: bool = False, + xattn: bool = False, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + self.xattn = xattn + + self.resblocks = nn.ModuleList( + [ + CustomResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + scale_cosine_attn=scale_cosine_attn, + scale_heads=scale_heads, + scale_attn=scale_attn, + scale_fc=scale_fc, + cross_attn=cross_attn, + xattn=xattn, + ) + for _ in range(layers) + ] + ) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None): + if k is None and v is None: + k = v = q + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + q = checkpoint(r, q, k, v, attn_mask) + else: + q = r(q, k, v, attn_mask=attn_mask) + return q + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + xattn: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + if xattn: + self.attn = Attention(d_model, n_head, xattn=True) + else: + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([("c_fc", nn.Linear(d_model, mlp_width)), ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, d_model))])) + + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + self.xattn = xattn + + def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None + if self.xattn: + return self.attn(x, attn_mask=attn_mask) + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask)) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + xattn: bool = False, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList([ResidualAttentionBlock(width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn) for _ in range(layers)]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def 
forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + patch_dropout: float = 0.0, + global_average_pool: bool = False, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + xattn: bool = False, + ): + super().__init__() + self.image_size = to_2tuple(image_size) + self.patch_size = to_2tuple(patch_size) + self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1]) + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity() + self.ln_pre = norm_layer(width) + + self.transformer = Transformer(width, layers, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn) + + self.global_average_pool = global_average_pool + self.ln_post = norm_layer(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + for param in self.parameters(): + param.requires_grad = False + + if unlocked_groups != 0: + groups = [ + [ + self.conv1, + self.class_embedding, + self.positional_embedding, + self.ln_pre, + ], + *self.transformer.resblocks[:-1], + [ + self.transformer.resblocks[-1], + self.ln_post, + ], + self.proj, + ] + + def _unlock(x): + if isinstance(x, Sequence): + for g in x: + _unlock(g) + else: + if isinstance(x, torch.nn.Parameter): + x.requires_grad = True + else: + for p in x.parameters(): + p.requires_grad = True + + _unlock(groups[-unlocked_groups:]) + + def get_num_layers(self): + return self.transformer.layers + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + return {"positional_embedding", "class_embedding"} + + def forward(self, x: torch.Tensor, return_all_features: bool = False): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in + x = self.patch_dropout(x) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + if not return_all_features: + if self.global_average_pool: + x = x.mean(dim=1) # x = x[:,1:,:].mean(dim=1) + else: + x = x[:, 0] + + x = self.ln_post(x) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class TextTransformer(nn.Module): + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + xattn: bool = False, + attn_mask: bool = True, + ): + super().__init__() + self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + + self.token_embedding = nn.Embedding(vocab_size, width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width)) + self.transformer = Transformer(width=width, layers=layers, heads=heads, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn) + + self.xattn = xattn + self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + if attn_mask: + self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False) + else: + self.attn_mask = None + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width**-0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + # return {'positional_embedding', 'token_embedding'} + return {"positional_embedding"} + + def get_num_layers(self): + return self.transformer.layers + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, text, return_all_features: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + # x = self.transformer(x) # no attention mask is applied + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + if not return_all_features: + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = 
x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return x diff --git a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/utils.py b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..73b6d66273fcbaf5b8c19df5e304db533ba88c49 --- /dev/null +++ b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_clip/utils.py @@ -0,0 +1,321 @@ +from itertools import repeat +import collections.abc +import logging +import math +import numpy as np + +import torch +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d +import torch.nn.functional as F + + +# open CLIP +def resize_clip_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get("visual.positional_embedding", None) + if old_pos_embed is None or not hasattr(model.visual, "grid_size"): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info("Resizing position embedding grid-size from %s to %s", old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + align_corners=True, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict["visual.positional_embedding"] = new_pos_embed + + +def resize_visual_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get("positional_embedding", None) + if old_pos_embed is None or not hasattr(model.visual, "grid_size"): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info("Resizing position embedding grid-size from %s to %s", old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + align_corners=True, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict["positional_embedding"] = new_pos_embed + + +def resize_evaclip_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1): + all_keys = 
list(state_dict.keys()) + # interpolate position embedding + if "visual.pos_embed" in state_dict: + pos_embed_checkpoint = state_dict["visual.pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.visual.patch_embed.num_patches + # num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches + num_extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict["visual.pos_embed"] = new_pos_embed + + patch_embed_proj = state_dict["visual.patch_embed.proj.weight"] + patch_size = model.visual.patch_embed.patch_size + state_dict["visual.patch_embed.proj.weight"] = torch.nn.functional.interpolate(patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False) + + +def resize_eva_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1): + all_keys = list(state_dict.keys()) + # interpolate position embedding + if "pos_embed" in state_dict: + pos_embed_checkpoint = state_dict["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.visual.patch_embed.num_patches + # num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches + num_extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict["pos_embed"] = new_pos_embed + + patch_embed_proj = state_dict["patch_embed.proj.weight"] + patch_size = model.visual.patch_embed.patch_size + state_dict["patch_embed.proj.weight"] = torch.nn.functional.interpolate(patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False) + + +def resize_rel_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1): + all_keys = list(state_dict.keys()) + for key in all_keys: + if 
"relative_position_index" in key: + state_dict.pop(key) + + if "relative_position_bias_table" in key: + rel_pos_bias = state_dict[key] + src_num_pos, num_attn_heads = rel_pos_bias.size() + dst_num_pos, _ = model.visual.state_dict()[key].size() + dst_patch_shape = model.visual.patch_embed.patch_shape + if dst_patch_shape[0] != dst_patch_shape[1]: + raise NotImplementedError() + num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) + src_size = int((src_num_pos - num_extra_tokens) ** 0.5) + dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) + if src_size != dst_size: + print("Position interpolate for %s from %dx%d to %dx%d" % (key, src_size, src_size, dst_size, dst_size)) + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + + def geometric_progression(a, r, n): + return a * (1.0 - r**n) / (1.0 - r) + + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_size // 2) + if gp > dst_size // 2: + right = q + else: + left = q + + # if q > 1.090307: + # q = 1.090307 + + dis = [] + cur = 1 + for i in range(src_size // 2): + dis.append(cur) + cur += q ** (i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + y = r_ids + [0] + dis + + t = dst_size // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + + print("Original positions = %s" % str(x)) + print("Target positions = %s" % str(dx)) + + all_rel_pos_bias = [] + + for i in range(num_attn_heads): + z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() + f = F.interpolate.interp2d(x, y, z, kind="cubic") + all_rel_pos_bias.append(torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) + + rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + + new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) + state_dict[key] = new_rel_pos_bias + + # interpolate position embedding + if "pos_embed" in state_dict: + pos_embed_checkpoint = state_dict["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.visual.patch_embed.num_patches + num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict["pos_embed"] = new_pos_embed + + patch_embed_proj = state_dict["patch_embed.proj.weight"] + patch_size = model.visual.patch_embed.patch_size + state_dict["patch_embed.proj.weight"] = torch.nn.functional.interpolate(patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False) + + +def freeze_batch_norm_2d(module, module_match={}, name=""): + """ + Converts all `BatchNorm2d` 
and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = ".".join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = lambda n, x: _ntuple(n)(x) + + +def is_logging(args): + def is_global_master(args): + return args.rank == 0 + + def is_local_master(args): + return args.local_rank == 0 + + def is_master(args, local=False): + return is_local_master(args) if local else is_global_master(args) + + return is_master + + +class AllGather(torch.autograd.Function): + """An autograd function that performs allgather on a tensor. + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. 
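+ The custom backward below returns only the slice of grad_output that corresponds to this rank's local input (grad_output[rank * batch_size : (rank + 1) * batch_size]); gradient contributions from other ranks are not propagated back through the gather.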
+ """ + + @staticmethod + def forward(ctx, tensor, rank, world_size): + tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather(tensors_gather, tensor) + ctx.rank = rank + ctx.batch_size = tensor.shape[0] + return torch.cat(tensors_gather, 0) + + @staticmethod + def backward(ctx, grad_output): + return (grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)], None, None) + + +allgather = AllGather.apply diff --git a/blip3o/model/multimodal_encoder/dev_eva_clip/eva_vit.py b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..d29c7fabaf28c40a5f9868a11c52fa3c67aa15fb --- /dev/null +++ b/blip3o/model/multimodal_encoder/dev_eva_clip/eva_vit.py @@ -0,0 +1,140 @@ +# Based on EVA, BEIT, timm and DeiT code bases +# https://github.com/baaivision/EVA +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/microsoft/unilm/tree/master/beit +# https://github.com/facebookresearch/deit/ +# https://github.com/facebookresearch/dino +# --------------------------------------------------------' +# not tested yet +import math +from transformers import CLIPImageProcessor + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ +from .eva_clip import create_model_and_transforms, get_model_config +import torch +import torchvision +import time + + + +class EvaViTWrapper(nn.Module): + def __init__(self, vision_tower, args, delay_load=False): + super().__init__() + + self.is_loaded = False + self.vision_tower_name = vision_tower + self.pretrained = args.vision_tower_pretrained + self.args = args + + self.select_layer = args.mm_vision_select_layer + if self.select_layer < -1: + self.select_layer += 1 + self.select_feature = getattr(args, "mm_vision_select_feature", "patch") + + self.model_config = get_model_config(self.vision_tower_name) + + if not delay_load: + print(f"Loading vision tower: {vision_tower}") + self.load_model() + elif getattr(args, "unfreeze_mm_vision_tower", False): + # TODO: better detector is needed. 
+ print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") + self.load_model() + elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: + print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") + self.load_model() + + def load_model(self): + print(f"Loading: {self.vision_tower_name}") + print(f"Pretrained: {self.pretrained}") + time_start = time.time() + model, _, image_processor = create_model_and_transforms(self.vision_tower_name, self.pretrained, force_custom_clip=True, precision="fp16") + time_end = time.time() + print(f"Loaded: {self.vision_tower_name} in {time_end - time_start:.2f}s") + self.device = next(model.parameters()).device + self.dtype = next(model.parameters()).dtype + if self.device.type != "meta": + model = model.to("cuda") + self.vision_tower = model.visual + resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0] + normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0] + self.resize_transform_size = resize_transform.size + self.image_processor = CLIPImageProcessor.from_pretrained( + "openai/clip-vit-large-patch14", + crop_size=resize_transform.size, + size={"shortest_edge": resize_transform.size}, + image_mean=list(normalize_transform.mean), + image_std=list(normalize_transform.std), + ) + print(f"Loaded image processor: {self.image_processor}") + self.vision_tower.requires_grad_(False) + self.is_loaded = True + + def feature_select(self, image_features): + select_feature_type = self.select_feature + + # if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: + # select_every_k_layer = len(image_features) // 4 + # image_features = torch.cat([image_features[i] for i in range(select_every_k_layer + self.select_layer, len(image_features), select_every_k_layer)], dim=-1) + # select_feature_type = select_feature_type.replace("slicefour_", "") + # elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]: + # select_layers = [-1, -4, -7, -10, 6] + # image_features = torch.cat([image_features[i] for i in select_layers], dim=-1) + # select_feature_type = select_feature_type.replace("slice_m25811_f6_", "") + # else: + # image_features = image_features[self.select_layer] + + if select_feature_type == "patch": + image_features = image_features[:, 1:] + elif select_feature_type == "cls_patch": + image_features = image_features + else: + raise ValueError(f"Unexpected select feature: {select_feature_type}") + return image_features + + def train(self, mode=True): + self.training = mode + + if self.is_loaded: + self.vision_tower.eval() + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_features = self.vision_tower.forward_features(image.to(self.dtype), return_all_features=True) + image_features = self.feature_select(image_features).to(self.dtype) + image_features.append(image_features) + else: + image_features = self.vision_tower.forward_features(images.to(self.dtype), return_all_features=True) + image_features = self.feature_select(image_features).to(self.dtype) + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def hidden_size(self): + return self.model_config["vision_cfg"]["width"] + + @property + def num_patches(self): + return 
(self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]) ** 2 + + @property + def num_patches_per_side(self): + return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"] + + @property + def config(self): + return self.model_config + + @property + def image_size(self): + return self.model_config["vision_cfg"]["image_size"] diff --git a/blip3o/model/multimodal_encoder/eva_clip/eva_clip_encoder.py b/blip3o/model/multimodal_encoder/eva_clip/eva_clip_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..abc2be0255be2d53d78e175528933f135b614dd7 --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/eva_clip_encoder.py @@ -0,0 +1,75 @@ +import torch +import torch.nn as nn + +from .eva_clip_processors import EvaClipImageTrainProcessor +from .eva_vit import EVAEncoderWrapper +from .factory import list_models, add_model_config, get_model_config + + + +class EvaClipVisionTower(nn.Module): + def __init__(self, vision_tower, args, delay_load=False): + super().__init__() + + self.is_loaded = False + self.vision_tower_name = vision_tower + self.vision_tower_pretrained = args.vision_tower_pretrained + self.config = get_model_config(vision_tower) + + + if not delay_load: + print(f"Loading EVA ViT: {self.vision_tower_name}") + self.load_model() + elif getattr(args, "unfreeze_mm_vision_tower", False): + # TODO: better detector is needed. + print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") + self.load_model() + elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: + print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") + self.load_model() + else: + self.cfg_only = self.config + + + def load_model(self, device_map=None): + print(f"Pretrained: {self.vision_tower_pretrained}") + self.image_processor = EvaClipImageTrainProcessor(self.config["vision_cfg"]["image_size"]) + self.vision_tower = EVAEncoderWrapper(self.vision_tower_pretrained, self.config) + print(f"Loaded image processor: {self.image_processor}") + self.vision_tower.requires_grad_(False) + self.is_loaded = True + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)).to(image.dtype) + image_features.append(image_feature) + else: + image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(images.dtype) + + return image_features + + @property + def dtype(self): + return self.vision_tower.dtype + + @property + def device(self): + return self.vision_tower.device + + @property + def hidden_size(self): + return self.config["vision_cfg"]["width"] + + @property + def num_patches(self): + return (self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"]) ** 2 + + @property + def num_patches_per_side(self): + return self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"] + + @property + def image_size(self): + return self.config["vision_cfg"]["image_size"] diff --git a/blip3o/model/multimodal_encoder/eva_clip/eva_clip_processors.py b/blip3o/model/multimodal_encoder/eva_clip/eva_clip_processors.py new file mode 100644 index 0000000000000000000000000000000000000000..c612b8fc2bf7f8a81f2cbb57da8c50f15005900a --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/eva_clip_processors.py 
@@ -0,0 +1,74 @@ +""" +# Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP +""" + +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +from transformers.image_processing_utils import BatchFeature +from PIL import Image +from transformers.image_transforms import convert_to_rgb +import numpy as np + +class BaseProcessor: + def __init__(self): + self.transform = lambda x: x + return + + def __call__(self, item): + return self.transform(item) + + +class EvaClipImageBaseProcessor(BaseProcessor): + def __init__(self, mean=None, std=None): + self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean + self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std + self.normalize = transforms.Normalize(self.mean, self.std) + + + @property + def image_mean(self): + return self.mean + + +class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor): + def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): + super().__init__(mean=mean, std=std) + + self.transform = transforms.Compose( + [ + convert_to_rgb, + transforms.Resize( + image_size, + interpolation=InterpolationMode.BICUBIC, + ), + transforms.CenterCrop(image_size), + transforms.ToTensor(), + self.normalize, + ] + ) + + + self.image_size = image_size + + def preprocess(self, images, return_tensors): + if isinstance(images, Image.Image): + images = [images] + else: + assert isinstance(images, list) + + transformed_images = [self.transform(image).numpy() for image in images] + data = {"pixel_values": transformed_images} + + + return BatchFeature(data=data, tensor_type=return_tensors) + + def __call__(self, item): + return self.transform(item) + + @property + def crop_size(self): + return {"height": self.image_size, "width": self.image_size} + + @property + def size(self): + return {"shortest_edge": self.image_size} diff --git a/blip3o/model/multimodal_encoder/eva_clip/eva_vit.py b/blip3o/model/multimodal_encoder/eva_clip/eva_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..427779a010a2dc5cb6ff2f7fdeb3d0bc58258b9b --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/eva_vit.py @@ -0,0 +1,762 @@ +""" +# Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP +""" + +from math import pi +import torch +from torch import nn +from einops import rearrange, repeat +import logging + +from huggingface_hub import snapshot_download +cache_dir = snapshot_download(repo_id="jiuhai/eva_clip_vision_tower") + + + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), "invalid dimensions for broadcastable concatentation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... 
d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +class VisionRotaryEmbeddingFast(nn.Module): + def __init__(self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, patch_dropout=0.0): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum("..., f -> ... f", t, freqs) + freqs = repeat(freqs, "... n -> ... (n r)", r=2) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) + + freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) + freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) + + self.patch_dropout = patch_dropout + + self.register_buffer("freqs_cos", freqs_cos) + self.register_buffer("freqs_sin", freqs_sin) + + logging.info(f"Shape of rope freq: {self.freqs_cos.shape}") + + def forward(self, t, patch_indices_keep=None): + if patch_indices_keep is not None: + batch = t.size()[0] + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1]) + freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1]) + + freqs_cos = freqs_cos[batch_indices, patch_indices_keep] + freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j") + freqs_sin = freqs_sin[batch_indices, patch_indices_keep] + freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j") + + return t * freqs_cos + rotate_half(t) * freqs_sin + + return t * self.freqs_cos + rotate_half(t) * self.freqs_sin + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. 
+ self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + print(f"os.getenv('RoPE')={os.getenv('RoPE')}") + + def forward(self, x): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + if self.training and os.getenv('RoPE') == '1': + return x, patch_indices_keep + + return x + + +# -------------------------------------------------------- +# Adapted from https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- +import math +import os +import torch.nn as nn +import torch.nn.functional as F + +try: + from timm.models.layers import drop_path, to_2tuple, trunc_normal_ +except: + from timm.layers import drop_path, to_2tuple, trunc_normal_ + +if os.getenv("ENV_TYPE") == "deepspeed": + try: + from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint + except: + from torch.utils.checkpoint import checkpoint +else: + from torch.utils.checkpoint import checkpoint + +try: + import xformers.ops as xops +except ImportError: + xops = None + # print("Please 'pip install xformers'") + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
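+ During training, each sample's residual-branch output is zeroed with probability `drop_prob` (survivors are rescaled by 1 / (1 - drop_prob)); at inference this is the identity.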
+ """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + drop=0., + subln=False, + + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + + self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity() + + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commit this for the orignal BERT implement + x = self.ffn_ln(x) + + x = self.fc2(x) + x = self.drop(x) + return x + + +class SwiGLU(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0., + norm_layer=nn.LayerNorm, subln=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.w1 = nn.Linear(in_features, hidden_features) + self.w2 = nn.Linear(in_features, hidden_features) + + self.act = act_layer() + self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity() + self.w3 = nn.Linear(hidden_features, out_features) + + self.drop = nn.Dropout(drop) + + def forward(self, x): + x1 = self.w1(x) + x2 = self.w2(x) + hidden = self.act(x1) * x2 + x = self.ffn_ln(hidden) + x = self.w3(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.subln = subln + if self.subln: + self.q_proj = nn.Linear(dim, all_head_dim, bias=False) + self.k_proj = nn.Linear(dim, all_head_dim, bias=False) + self.v_proj = nn.Linear(dim, all_head_dim, bias=False) + else: + if qkv_bias: + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=True) + else: + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + + # if qkv_bias: + # self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + # self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + # else: + # self.q_bias = None + # self.v_bias = None + + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity() + # self.proj = nn.Linear(all_head_dim, all_head_dim) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.xattn = xattn + self.xattn_drop = attn_drop + + self.rope = rope + + def forward(self, x, rel_pos_bias=None, attn_mask=None): + B, N, C = x.shape + if self.subln: + q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias) + k = F.linear(input=x, weight=self.k_proj.weight, bias=None) + v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias) + + q = q.reshape(B, N, self.num_heads, 
-1).permute(0, 2, 1, 3) # B, num_heads, N, C + k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + else: + + # qkv_bias = None + # if self.q_bias is not None: + # qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + + # qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + + qkv = self.qkv(x) + + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C + q, k, v = qkv[0], qkv[1], qkv[2] + + if self.rope: + q_t = q[:, :, 1:, :] + ro_q_t = self.rope(q_t) + q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v) + + k_t = k[:, :, 1:, :] + ro_k_t = self.rope(k_t) + k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v) + + if self.xattn: + q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + x = xops.memory_efficient_attention( + q, k, v, + p=self.xattn_drop, + scale=self.scale, + ) + x = x.reshape(B, N, -1) + x = self.inner_attn_ln(x) + x = self.proj(x) + x = self.proj_drop(x) + else: + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0).type_as(attn) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias.type_as(attn) + + if attn_mask is not None: + attn_mask = attn_mask.bool() + attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf")) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.inner_attn_ln(x) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, + window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False, + subln=False, naiveswiglu=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim, + xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + if naiveswiglu: + self.mlp = SwiGLU( + in_features=dim, + hidden_features=mlp_hidden_dim, + subln=subln, + norm_layer=norm_layer, + ) + else: + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + subln=subln, + drop=drop + ) + + if init_values is not None and init_values > 0: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + self.postnorm = postnorm + + def forward(self, x, rel_pos_bias=None, attn_mask=None): + if self.gamma_1 is None: + if self.postnorm: + x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))) + x = x + self.drop_path(self.norm2(self.mlp(x))) + else: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + if self.postnorm: + x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))) + x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
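+ # [B, C, H, W] -> conv (stride = patch size) -> [B, embed_dim, H/ps, W/ps] -> flatten + transpose -> [B, num_patches, embed_dim]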
+ x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class RelativePositionBias(nn.Module): + + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + def forward(self): + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class EVAVisionTransformer(nn.Module): + """Vision Transformer with support for patch or hybrid CNN input stage""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0., + use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False, + use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False, + pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False, + ): + super().__init__() + self.image_size = img_size + # self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + self.rel_pos_bias = None + self.rope = None + + self.naiveswiglu = naiveswiglu + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, 
drop_path=dpr[i], norm_layer=norm_layer, + init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None, + xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu) + for i in range(depth)]) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() + + self.grad_checkpointing = grad_checkpointing + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + if self.naiveswiglu: + rescale(layer.mlp.w3.weight.data, layer_id + 1) + else: + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def get_cast_dtype(self) -> torch.dtype: + return self.blocks[0].mlp.fc2.weight.dtype + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_num_layers(self): + return len(self.blocks) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, "partial locking not currently supported for this model" + for param in self.parameters(): + param.requires_grad = False + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token"} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=""): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in + if os.getenv('RoPE') == '1': + if self.training and not isinstance(self.patch_dropout, nn.Identity): + x, patch_indices_keep = self.patch_dropout(x) + self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep) + else: + self.rope.forward = partial(self.rope.forward, patch_indices_keep=None) + x = self.patch_dropout(x) + else: + x = self.patch_dropout(x) + + rel_pos_bias = None + + for blk in self.blocks: + if self.grad_checkpointing: + x = checkpoint(blk, x, (rel_pos_bias,)) + else: + x = blk(x, rel_pos_bias=rel_pos_bias) + + return x + + def forward(self, x, return_all_features=False): + + """ + :return: + forward_features function returns raw features of ViT, + forward with return_all_features returns normalized features of ViT + :param x: + :param return_all_features: + """ + + features = self.forward_features(x) # [B, n_patch, C] + return features + + +def load_state_dict(checkpoint_path: str, map_location: str = "cpu", model_key: str = "model|module|state_dict", is_openai: bool = False, skip_list: list = []): + if is_openai: + model = torch.jit.load(checkpoint_path, map_location="cpu").eval() + state_dict = model.state_dict() + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + else: + checkpoint = torch.load(checkpoint_path, map_location=map_location) + for mk in model_key.split("|"): + if isinstance(checkpoint, dict) and mk in checkpoint: + state_dict = checkpoint[mk] + break + else: + state_dict = checkpoint + if next(iter(state_dict.items()))[0].startswith("module"): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + for k in skip_list: + if k in list(state_dict.keys()): + logging.info(f"Removing key {k} from pretrained checkpoint") + del state_dict[k] + + if os.getenv("RoPE") == "1": + for k in list(state_dict.keys()): + if "freqs_cos" in k or "freqs_sin" in k: + del state_dict[k] + return state_dict + + +def load_clip_visual_state_dict(checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []): + state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list) + # for k in list(state_dict.keys()): + # if not k.startswith("visual."): + # del state_dict[k] + # for k in list(state_dict.keys()): + # if k.startswith("visual."): + # new_k = k[7:] + # state_dict[new_k] = state_dict[k] + # del state_dict[k] + return state_dict + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +try: + from apex.normalization import FusedLayerNorm +except: + FusedLayerNorm = LayerNorm + # print("Please build and install Nvidia apex package with option '--cuda_ext' according to https://github.com/NVIDIA/apex#from-source .") + + +@dataclass +class CLIPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + head_width: int = 64 + mlp_ratio: float = 4.0 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + ls_init_value: Optional[float] = None # layer scale initial value + patch_dropout: float = 0.0 # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) + drop_path_rate: Optional[float] = None # drop path rate + 
timm_model_name: str = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = "linear" # linear projection for timm model output ('linear', 'mlp', '') + timm_proj_bias: bool = False # enable bias final projection + eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size + qkv_bias: bool = True + fusedLN: bool = False + xattn: bool = False + postnorm: bool = False + rope: bool = False + pt_hw_seq_len: int = 16 # 224/14 + intp_freq: bool = False + naiveswiglu: bool = False + subln: bool = False + + +def create_norm_layer_factory(use_fused_ln, eps=1e-6): + # Otherwise, use the standard LayerNorm + return lambda num_features: nn.LayerNorm(num_features, eps=eps) + + +def _build_vision_tower(vision_tower_path: str, embed_dim: int, vision_cfg: CLIPVisionCfg, **kwargs): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + if vision_cfg.eva_model_name: + vision_heads = vision_cfg.width // vision_cfg.head_width + # Determine the appropriate norm layer factory based on the configuration + norm_layer_factory = create_norm_layer_factory(vision_cfg.fusedLN, eps=1e-6) + + # breakpoint() + visual = EVAVisionTransformer( + img_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + num_classes=embed_dim, + use_mean_pooling=vision_cfg.global_average_pool, # False + init_values=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + embed_dim=vision_cfg.width, + depth=vision_cfg.layers, + num_heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + qkv_bias=vision_cfg.qkv_bias, + drop_path_rate=vision_cfg.drop_path_rate, + norm_layer=norm_layer_factory, + xattn=vision_cfg.xattn, + rope=vision_cfg.rope, + postnorm=vision_cfg.postnorm, + pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14 + intp_freq=vision_cfg.intp_freq, + naiveswiglu=vision_cfg.naiveswiglu, + subln=vision_cfg.subln, + ) + # breakpoint() + state_dict = load_clip_visual_state_dict(vision_tower_path) + incompatible_keys = visual.load_state_dict(state_dict, strict=False) + print("EVA-CLIP incompatible_keys:", incompatible_keys) + + return visual + + +class EVAEncoderWrapper(nn.Module): + def __init__(self, vision_tower_pretrained, config): + super(EVAEncoderWrapper, self).__init__() + self.config = config + self.config["vision_tower_path"] = os.path.join(cache_dir, "pytorch_model.bin") + self.model = _build_vision_tower(**self.config) + + def forward(self, image, **kwargs): + encode = self.model(image, return_all_features=True)[:, 1:, :] # remove the CLS token + return encode + + @property + def dtype(self): + return list(self.parameters())[-1].dtype + + @property + def device(self): + return list(self.parameters())[-1].device diff --git a/blip3o/model/multimodal_encoder/eva_clip/factory.py b/blip3o/model/multimodal_encoder/eva_clip/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..f40ee2066c767dd0d8491173186633ab7ef05ee1 --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/factory.py @@ -0,0 +1,59 @@ +import json +import logging +import os +import pathlib +import re +from copy import deepcopy +from pathlib import Path +from typing import Optional, Tuple, Union, Dict, Any +import torch + +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model 
architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = (".json",) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f"*{ext}")) + for cf in config_files: + with open(cf, "r", encoding="utf8") as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))) + + +_rescan_model_configs() # initial populate of model config registry + + +def list_models(): + """enumerate available model architectures based on config files""" + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """add model config path or file and update registry""" + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() + + +def get_model_config(model_name): + if model_name in _MODEL_CONFIGS: + return deepcopy(_MODEL_CONFIGS[model_name]) + else: + return None diff --git a/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-18B.json b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-18B.json new file mode 100644 index 0000000000000000000000000000000000000000..4917556693fe9dcbddeadf7459be363740d55aa5 --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-18B.json @@ -0,0 +1,27 @@ +{ + "embed_dim": 1536, + "vision_cfg": { + "image_size": 224, + "layers": 48, + "width": 5120, + "head_width": 128, + "mlp_ratio": 5, + "patch_size": 14, + "eva_model_name": "eva-clip-18b-14-x", + "drop_path_rate": 0, + "qkv_bias": false, + "xattn": true, + "postnorm": true, + "fusedLN": false, + "use_rms_norm": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": false + } +} \ No newline at end of file diff --git a/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B-plus.json b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..7d843daa36c36b6100291fa3a2cb58672db6bbfb --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B-plus.json @@ -0,0 +1,27 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 448, + "layers": 32, + "width": 4096, + "head_width": 128, + "mlp_ratio": 5, + "patch_size": 14, + "eva_model_name": "eva-clip-8b-14-plus-x", + "drop_path_rate": 0, + "qkv_bias": false, + "xattn": true, + "postnorm": false, + "fusedLN": false, + "use_rms_norm": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": false + } +} \ No newline at end of file diff --git a/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B.json b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B.json new file mode 100644 index 0000000000000000000000000000000000000000..689492a25d365436fd85ed432e6fb7295ca1c7bd --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B.json @@ -0,0 +1,27 @@ +{ + "embed_dim": 1280, + 
"vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 4096, + "head_width": 128, + "mlp_ratio": 5, + "patch_size": 14, + "eva_model_name": "eva-clip-8b-14-x", + "drop_path_rate": 0, + "qkv_bias": false, + "xattn": true, + "postnorm": false, + "fusedLN": false, + "use_rms_norm": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": false + } +} \ No newline at end of file diff --git a/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-B-16.json b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..aad2058003962a4ab286bf4e1ae956288af34e62 --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-B-16.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16, + "eva_model_name": "eva-clip-b-16", + "ls_init_value": 0.1, + "drop_path_rate": 0.0 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..100279572ff6d1bcca601f0eb526b4d4ff174c7d --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14, + "eva_model_name": "eva-clip-g-14-x", + "drop_path_rate": 0, + "xattn": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14.json b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..5d338b4e6104241d1f0304ee82400035d5385332 --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14, + "eva_model_name": "eva-clip-g-14-x", + "drop_path_rate": 0.4, + "xattn": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-B-16.json b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..e4a6e723f77033caa341ddf9b5be1787d64ad42c --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-B-16.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "head_width": 64, + "patch_size": 16, + "mlp_ratio": 2.6667, + "eva_model_name": "eva-clip-b-16-X", + "drop_path_rate": 0.0, + "xattn": true, + "fusedLN": true, + 
"rope": true, + "pt_hw_seq_len": 16, + "intp_freq": true, + "naiveswiglu": true, + "subln": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "xattn": true, + "fusedLN": true + } +} \ No newline at end of file diff --git a/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14-336.json b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14-336.json new file mode 100644 index 0000000000000000000000000000000000000000..3e1d124e1118911c5ad7b1ce85df195aca363ac4 --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14-336.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "drop_path_rate": 0, + "head_width": 64, + "mlp_ratio": 2.6667, + "patch_size": 14, + "eva_model_name": "eva-clip-l-14-336", + "xattn": true, + "fusedLN": true, + "rope": true, + "pt_hw_seq_len": 16, + "intp_freq": true, + "naiveswiglu": true, + "subln": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14.json b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..03b22ad3cfb92f9c843b9ec8d672e57e7a9ba4a2 --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "drop_path_rate": 0, + "head_width": 64, + "mlp_ratio": 2.6667, + "patch_size": 14, + "eva_model_name": "eva-clip-l-14", + "xattn": true, + "fusedLN": true, + "rope": true, + "pt_hw_seq_len": 16, + "intp_freq": true, + "naiveswiglu": true, + "subln": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..0ec28217e945a9a26feee61d67f18262e3e5804e --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json @@ -0,0 +1,28 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 64, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.571428571428571, + "patch_size": 14, + "eva_model_name": "eva-clip-4b-14-x", + "drop_path_rate": 0, + "xattn": true, + "postnorm": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": true + } +} + + + diff --git a/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14.json b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14.json new file mode 100644 index 0000000000000000000000000000000000000000..747ffccc8bd49dbb6701b58e15843b7fe3754e64 --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 64, + "width": 1792, + "head_width": 112, + "mlp_ratio": 
8.571428571428571, + "patch_size": 14, + "eva_model_name": "eva-clip-4b-14-x", + "drop_path_rate": 0, + "xattn": true, + "postnorm": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/blip3o/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json b/blip3o/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json new file mode 100644 index 0000000000000000000000000000000000000000..ad71aff86a4d3b0e34c0bb55ea2b8e3ac220477a --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 448, + "layers": 77, + "width": 2304, + "head_width": 144, + "mlp_ratio": 10.9722, + "patch_size": 14, + "eva_model_name": "eva-clip-10b-14-x", + "drop_path_rate": 0, + "xattn": true, + "postnorm": false, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": true + } +} diff --git a/blip3o/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json b/blip3o/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json new file mode 100644 index 0000000000000000000000000000000000000000..21b20680715e6e719b6c8d51c86701a4811ac37f --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 77, + "width": 2304, + "head_width": 144, + "mlp_ratio": 10.9722, + "patch_size": 14, + "eva_model_name": "eva-clip-10b-14-x", + "drop_path_rate": 0, + "xattn": true, + "postnorm": false, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": true + } +} diff --git a/blip3o/model/multimodal_encoder/eva_clip/model_configs/eva-clip-E-14-plus.json b/blip3o/model/multimodal_encoder/eva_clip/model_configs/eva-clip-E-14-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..8968bb056e571b80d940a338ed0ed93bafc19492 --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/model_configs/eva-clip-E-14-plus.json @@ -0,0 +1,27 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "drop_path_rate": 0, + "eva_model_name": "eva-clip-E-14-plus", + "head_width": 112, + "image_size": 448, + "layers": 64, + "mlp_ratio": 8.571428571428571, + "patch_size": 14, + "postnorm": true, + "qkv_bias": true, + "width": 1792, + "xattn": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": true + } +} + + diff --git a/blip3o/model/multimodal_encoder/eva_clip/visual.py b/blip3o/model/multimodal_encoder/eva_clip/visual.py new file mode 100644 index 0000000000000000000000000000000000000000..4a9c9f5f5c0891fb7e1b145825231ed94ac1b9f1 --- /dev/null +++ b/blip3o/model/multimodal_encoder/eva_clip/visual.py @@ -0,0 +1,450 @@ +# -------------------------------------------------------- +# Adapted from https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import os +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as 
F +from torch.utils.checkpoint import checkpoint + +try: + from timm.models.layers import drop_path, to_2tuple +except: + from timm.layers import drop_path, to_2tuple + +try: + import xformers.ops as xops +except ImportError: + xops = None + print("Please 'pip install xformers'") + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + print(f"os.getenv('RoPE')={os.getenv('RoPE')}") + + def forward(self, x): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + if self.training and os.getenv('RoPE') == '1': + return x, patch_indices_keep + + return x + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + drop=0., + subln=False, + + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + + self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity() + + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commit this for the orignal BERT implement + x = self.ffn_ln(x) + + x = self.fc2(x) + x = self.drop(x) + return x + +class SwiGLU(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0., + norm_layer=nn.LayerNorm, subln=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.w1 = nn.Linear(in_features, hidden_features) + self.w2 = nn.Linear(in_features, hidden_features) + + self.act = act_layer() + self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity() + self.w3 = nn.Linear(hidden_features, out_features) + + self.drop = nn.Dropout(drop) + + def forward(self, x): + x1 = self.w1(x) + x2 = self.w2(x) + hidden = self.act(x1) * x2 + x = self.ffn_ln(hidden) + x = self.w3(x) + x = self.drop(x) + return x + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + 
if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.subln = subln + if self.subln: + self.q_proj = nn.Linear(dim, all_head_dim, bias=False) + self.k_proj = nn.Linear(dim, all_head_dim, bias=False) + self.v_proj = nn.Linear(dim, all_head_dim, bias=False) + else: + if qkv_bias: + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=True) + else: + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + + # if qkv_bias: + # self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + # self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + # else: + # self.q_bias = None + # self.v_bias = None + + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity() + # self.proj = nn.Linear(all_head_dim, all_head_dim) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.xattn = xattn + self.xattn_drop = attn_drop + + self.rope = rope + + def forward(self, x, rel_pos_bias=None, attn_mask=None): + B, N, C = x.shape + if self.subln: + q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias) + k = F.linear(input=x, weight=self.k_proj.weight, bias=None) + v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias) + + q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C + k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + else: + + # qkv_bias = None + # if self.q_bias is not None: + # qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + + # qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + + qkv = self.qkv(x) + + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C + q, k, v = qkv[0], qkv[1], qkv[2] + + if self.rope: + q_t = q[:, :, 1:, :] + ro_q_t = self.rope(q_t) + q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v) + + k_t = k[:, :, 1:, :] + ro_k_t = self.rope(k_t) + k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v) + + if self.xattn: + q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + x = xops.memory_efficient_attention( + q, k, v, + p=self.xattn_drop, + scale=self.scale, + ) + x = x.reshape(B, N, -1) + x = self.inner_attn_ln(x) + x = self.proj(x) + x = self.proj_drop(x) + else: + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0).type_as(attn) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias.type_as(attn) + + if attn_mask is not None: + attn_mask = attn_mask.bool() + attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf")) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.inner_attn_ln(x) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def 
__init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, + window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False, + subln=False, naiveswiglu=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim, + xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + if naiveswiglu: + self.mlp = SwiGLU( + in_features=dim, + hidden_features=mlp_hidden_dim, + subln=subln, + norm_layer=norm_layer, + ) + else: + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + subln=subln, + drop=drop + ) + + if init_values is not None and init_values > 0: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + self.postnorm = postnorm + + def forward(self, x, rel_pos_bias=None, attn_mask=None): + if self.gamma_1 is None: + if self.postnorm: + x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))) + x = x + self.drop_path(self.norm2(self.mlp(x))) + else: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + if self.postnorm: + x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))) + x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
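+        # the conv projection yields [B, embed_dim, H/patch, W/patch]; flatten the spatial dims and transpose to [B, num_patches, embed_dim]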
+ x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class EVAVisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0., + use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False, + use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False, + pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False, + ): + super().__init__() + self.image_size = img_size + # self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + self.rel_pos_bias = None + self.rope = None + + self.naiveswiglu = naiveswiglu + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None, + xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu) + for i in range(depth)]) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() + + self.grad_checkpointing = grad_checkpointing + + + def get_num_layers(self): + return len(self.blocks) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + + def forward_features(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in + if os.getenv('RoPE') == '1': + if self.training and not isinstance(self.patch_dropout, nn.Identity): + x, patch_indices_keep = self.patch_dropout(x) + self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep) + else: + self.rope.forward = partial(self.rope.forward, patch_indices_keep=None) + x = self.patch_dropout(x) + else: + x = self.patch_dropout(x) + + rel_pos_bias = None + + for blk in self.blocks: + if self.grad_checkpointing: + x = checkpoint(blk, x, (rel_pos_bias,)) + else: + x = blk(x, rel_pos_bias=rel_pos_bias) + + return x + + def forward(self, x): + + """ + :return: + forward_features function returns raw features of ViT, + forward with return_all_features returns normalized features of ViT + :param x: + :param return_all_features: + """ + + features = self.forward_features(x) # [B, n_patch, C] + + return features \ No newline at end of file diff --git a/blip3o/model/multimodal_encoder/imagebind.py b/blip3o/model/multimodal_encoder/imagebind.py new file mode 100644 index 0000000000000000000000000000000000000000..8bbe71c7b42e25ff7e5a8912b403498002a26348 --- /dev/null +++ b/blip3o/model/multimodal_encoder/imagebind.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn + +from transformers import CLIPImageProcessor + +try: + from imagebind.models import imagebind_model + from imagebind.models.imagebind_model import ModalityType + from imagebind.data import load_and_transform_audio_data +except ImportError: + pass + + +class ImageBindWrapper(nn.Module): + def __init__(self, vision_tower, select_layer, select_feature="patch", delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_tower_name = vision_tower + self.select_layer = select_layer + self.select_feature = select_feature + + if not delay_load: + self.load_model() + + def load_model(self): + self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") + self.vision_tower = imagebind_model.imagebind_huge(pretrained=True) + for p in self.vision_tower.parameters(): + p.requires_grad = False + self.vision_tower.eval() + self.is_loaded = True + + def train(self, mode=True): + self.training = mode + + if self.is_loaded: + self.vision_tower.eval() + + @torch.no_grad() + def forward(self, x): + if type(x) == dict: + if x["audios"] is not None: + inputs = {ModalityType.AUDIO: load_and_transform_audio_data(x["audios"], device=self.device).half()} + embeddings = self.vision_tower(inputs) + audio_embedding = embeddings[ModalityType.AUDIO] + return audio_embedding.unsqueeze(1) + else: + inputs = {ModalityType.VISION: x.to(dtype=self.dtype)} + embeddings = self.vision_tower(inputs) + vision_embedding = embeddings[ModalityType.VISION] + if vision_embedding.ndim == 2: + return vision_embedding.unsqueeze(1) + if vision_embedding.shape[1] == 257: + return vision_embedding[:, 1:] + raise ValueError(f"Unexpected shape: {vision_embedding.shape}") + + @property + def dummy_feature(self): + return torch.zeros(1, 1024, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + return self.vision_tower.modality_preprocessors.vision.cls_token.dtype + + @property + def device(self): + return self.vision_tower.modality_preprocessors.vision.cls_token.device + + @property + def hidden_size(self): + return 1024 diff --git a/blip3o/model/multimodal_encoder/open_clip_encoder.py b/blip3o/model/multimodal_encoder/open_clip_encoder.py new file mode 100644 index 
0000000000000000000000000000000000000000..b88d618c3c13fd416771e67826d4d621b896575f --- /dev/null +++ b/blip3o/model/multimodal_encoder/open_clip_encoder.py @@ -0,0 +1,162 @@ +import torch +import torch.nn as nn +from transformers import CLIPImageProcessor + +try: + import open_clip + import torchvision + from open_clip.transformer import _expand_token +except ImportError: + print("OpenCLIP not installed") + open_clip = None + +HIDDEN_SIZE_DICT = { + "ViT-H-14-378-quickgelu": 1280, +} + + +class OpenCLIPVisionTower(nn.Module): + def __init__(self, vision_tower, args, delay_load=False): + super().__init__() + + self.is_loaded = False + self.model_name = vision_tower.replace("open_clip_hub:", "") + self.pretrained = args.vision_tower_pretrained + self.select_layer = args.mm_vision_select_layer + self.select_feature = getattr(args, "mm_vision_select_feature", "patch") + + if not delay_load: + print(f"Loading vision tower: {vision_tower}") + self.load_model() + elif getattr(args, "unfreeze_mm_vision_tower", False): + # TODO: better detector is needed. + print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") + self.load_model() + elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: + print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") + self.load_model() + + def load_model(self, device_map="auto"): + print(f"Loading OpenCLIP model: {self.model_name}") + print(f"Pretrained: {self.pretrained}") + vision_tower, _, image_processor = open_clip.create_model_and_transforms(model_name=self.model_name, pretrained=self.pretrained, precision="fp32", device="cuda") + + resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0] + normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0] + self.resize_transform_size = resize_transform.size # 224 or 384 + self.patch_size = vision_tower.visual.conv1.kernel_size[0] # 14 or 16 + + self.image_processor = CLIPImageProcessor.from_pretrained( + "openai/clip-vit-large-patch14", + crop_size=resize_transform.size, + size={"shortest_edge": resize_transform.size}, + image_mean=list(normalize_transform.mean), + image_std=list(normalize_transform.std), + ) + print(f"Loaded image processor: {self.image_processor}") + self.vision_tower = vision_tower.visual + self.vision_tower.requires_grad_(False) + + self.is_loaded = True + + def feature_select(self, image_forward_outs): + image_features = image_forward_outs[self.select_layer] + if self.select_feature == "patch": + image_features = image_features[:, 1:] + elif self.select_feature == "cls_patch": + image_features = image_features + elif self.select_feature == "conv_flatten": + image_features = image_features.flatten(2).transpose(1, 2) + else: + raise ValueError(f"Unexpected select feature: {self.select_feature}") + return image_features + + def forward_visual(self, x, output_hidden_states=False): + if hasattr(self.vision_tower, "trunk") and hasattr(self.vision_tower.trunk, "_intermediate_layers"): + return self.vision_tower.trunk._intermediate_layers(x, abs(self.select_layer)) + else: + + def forward_openclip(self, x: torch.Tensor): + features = [] + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + 
x = torch.cat( + [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], + dim=1, + ) + # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + x = self.patch_dropout(x) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + for r in self.transformer.resblocks: + x = r(x, attn_mask=None) + features.append(x) + return features + + return forward_openclip(self.vision_tower, x) + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_forward_out = self.forward_visual(image.to(self.dtype).unsqueeze(0), output_hidden_states=True) + image_feature = self.feature_select(image_forward_out).to(image.dtype) + image_features.append(image_feature) + else: + image_forward_outs = self.forward_visual(images.to(self.dtype), output_hidden_states=True) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + if hasattr(self.vision_tower, "conv1"): + return self.vision_tower.conv1.weight.dtype + if hasattr(self.vision_tower, "trunk"): + return self.vision_tower.trunk.patch_embed.proj.weight.dtype + raise NotImplementedError + + @property + def device(self): + if hasattr(self.vision_tower, "conv1"): + return self.vision_tower.conv1.weight.device + if hasattr(self.vision_tower, "trunk"): + return self.vision_tower.trunk.patch_embed.proj.weight.device + raise NotImplementedError + + @property + def config(self): + return None + + @property + def hidden_size(self): + if self.model_name in HIDDEN_SIZE_DICT: + return HIDDEN_SIZE_DICT[self.model_name] + else: + raise NotImplementedError + + @property + def num_patches(self): + image_size = self.resize_transform_size if isinstance(self.resize_transform_size, int) else self.resize_transform_size[0] + _num_patches = (image_size // self.patch_size) ** 2 + if "cls_patch" in self.select_feature: + _num_patches += 1 + return _num_patches + + @property + def image_size(self): + return self.resize_transform_size + + @property + def num_patches_per_side(self): + return self.resize_transform_size // self.patch_size diff --git a/blip3o/model/multimodal_encoder/siglip_encoder.py b/blip3o/model/multimodal_encoder/siglip_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f8654a1a4b80d8ab5f09fb9fe514c80e356b88e0 --- /dev/null +++ b/blip3o/model/multimodal_encoder/siglip_encoder.py @@ -0,0 +1,637 @@ +from typing import Optional, Tuple, Union, Dict +from dataclasses import dataclass +from functools import partial, reduce +from PIL import Image +import torch +import torch.utils.checkpoint +from torch import nn +import os +from transformers.image_processing_utils import BatchFeature, get_size_dict +from transformers.image_transforms import ( + convert_to_rgb, + normalize, + rescale, + resize, + to_channel_dimension_format, +) +from transformers.image_utils import ( + ChannelDimension, + PILImageResampling, + to_numpy_array, +) +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel +from transformers import PretrainedConfig +from transformers.utils import ModelOutput +import numpy as np + + +class SigLipImageProcessor: + def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: 
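+                # prepend the learned class embedding so the token sequence becomes [CLS, patch_1, ..., patch_N] before positional embeddings are added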
Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST): + crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384} + crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") + + self.image_mean = image_mean + self.image_std = image_std + self.size = size + self.resample = resample + self.rescale_factor = rescale_factor + self.data_format = data_format + self.crop_size = crop_size + + def preprocess(self, images, return_tensors): + if isinstance(images, Image.Image): + images = [images] + else: + # to adapt video data + images = [to_numpy_array(image) for image in images] + assert isinstance(images, list) + + + try: + transforms = [ + convert_to_rgb, + to_numpy_array, + partial(resize, size=self.size, resample=self.resample, data_format=self.data_format), + partial(rescale, scale=self.rescale_factor, data_format=self.data_format), + partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format), + partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format), + ] + + images = reduce(lambda x, f: [*map(f, x)], transforms, images) + data = {"pixel_values": images} + except ValueError as e: + try: + transforms = [ + convert_to_rgb, + to_numpy_array, + partial(resize, size=self.size, resample=self.resample, data_format=self.data_format), + partial(rescale, scale=self.rescale_factor, data_format=self.data_format), + partial(normalize, mean=self.image_mean[0], std=self.image_std[0], data_format=self.data_format), + partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format), + ] + images = reduce(lambda x, f: [*map(f, x)], transforms, images) + processed_images = [np.repeat(img, repeats=3, axis=0) for img in images] + data = {"pixel_values": processed_images} + except ValueError as e: + print(f"Grayscale processing failed: {e}") + + + + return BatchFeature(data=data, tensor_type=return_tensors) + + +class SigLipVisionConfig(PretrainedConfig): + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size=1152, + image_mean=(0.5, 0.5, 0.5), + intermediate_size=4304, + num_hidden_layers=27, + num_attention_heads=16, + num_channels=3, + image_size=384, + patch_size=14, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.image_mean = image_mean + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from SigLipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + print(f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. 
This is not supported for all configurations of models and can yield errors.") + + return cls.from_dict(config_dict, **kwargs) + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->SigLip +class SigLipVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class SigLipVisionEmbeddings(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class SigLipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + 
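+        # attention scores are scaled by 1/sqrt(head_dim) before the softmax (standard scaled dot-product attention)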
self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError(f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" f" {attn_weights.size()}") + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError(f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}") + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError(f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->SigLip +class SigLipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->SigLip +class SigLipEncoderLayer(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = SigLipAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SigLipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + # Ignore copy + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + 
output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SigLipPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SigLipVisionConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + pass + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip +class SigLipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SigLipEncoderLayer`]. + + Args: + config: SigLipVisionConfig + """ + + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions) + + +class SigLipVisionTransformer(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SigLipVisionEmbeddings(config) + self.encoder = SigLipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.head = SigLipMultiheadAttentionPoolingHead(config) + + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = self.head(last_hidden_state) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class SigLipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: SigLipVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SigLipMLP(config) + + def forward(self, hidden_state): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + 
hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class SigLipVisionModel(SigLipPreTrainedModel): + config_class = SigLipVisionConfig + main_input_name = "pixel_values" + _no_split_modules = ["SigLipEncoderLayer"] + + def __init__(self, config: SigLipVisionConfig): + super().__init__(config) + + self.vision_model = SigLipVisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, SigLipVisionModel + + >>> model = SigLipVisionModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SigLipVisionTower(nn.Module): + def __init__(self, vision_tower, vision_tower_cfg, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.config = SigLipVisionConfig() + + self.vision_tower_name = vision_tower + + self.image_processor = SigLipImageProcessor() + + if not delay_load: + # rank0_print(f"Loading vision tower: {vision_tower}") + self.load_model() + elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False): + # TODO: better detector is needed. 
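+            # load eagerly here: an unfrozen (trainable) vision tower must have its weights instantiated before training begins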
+ # rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") + self.load_model() + elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts: + # rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") + self.load_model() + else: + self.cfg_only = self.config + + def load_model(self, device_map=None): + if self.is_loaded: + # rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name)) + return + + self.vision_tower = SigLipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) + + # del self.vision_tower.vision_model.encoder.layers[-1:] + self.vision_tower.vision_model.head = nn.Identity() + self.vision_tower.requires_grad_(False) + + self.is_loaded = True + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) + # image_feature = image_forward_out.hidden_states[-1].to(image.dtype) + image_feature = image_forward_out.last_hidden_state.to(image.dtype) + assert image_features.shape[-2] == 729 + image_features.append(image_feature) + else: + image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) + # image_features = image_forward_outs.hidden_states[-1].to(images.dtype) + image_features = image_forward_outs.last_hidden_state.to(images.dtype) + assert image_features.shape[-2] == 729 + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + for p in self.vision_tower.parameters(): + return p.dtype + + @property + def device(self): + for p in self.vision_tower.parameters(): + return p.device + + @property + def hidden_size(self): + return self.config.hidden_size + + @property + def num_patches(self): + return (self.config.image_size // self.config.patch_size) ** 2 + + @property + def num_patches_per_side(self): + return self.config.image_size // self.config.patch_size + # return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"] + + @property + def image_size(self): + return self.config.image_size \ No newline at end of file diff --git a/blip3o/model/multimodal_projector/builder.py b/blip3o/model/multimodal_projector/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..5d5b9702ac3469d39eca3ee54fc591d597a43cd5 --- /dev/null +++ b/blip3o/model/multimodal_projector/builder.py @@ -0,0 +1,78 @@ +import torch +import torch.nn as nn +import re + + +class IdentityMap(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, *args, **kwargs): + return x + + @property + def config(self): + return {"mm_projector_type": 'identity'} + + +class SimpleResBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.pre_norm = nn.LayerNorm(channels) + + self.proj = nn.Sequential( + nn.Linear(channels, channels), + nn.GELU(), + nn.Linear(channels, channels) + ) + def forward(self, x): + x = self.pre_norm(x) + return x + self.proj(x) + + +def build_vision_projector(config, delay_load=False, **kwargs): + projector_type = getattr(config, 'mm_projector_type', 'linear') + + if projector_type == 'linear': + return 
nn.Linear(config.mm_hidden_size, config.hidden_size) + + mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) + if mlp_gelu_match: + mlp_depth = int(mlp_gelu_match.group(1)) + modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + return nn.Sequential(*modules) + + if projector_type == 'identity': + return IdentityMap() + + raise ValueError(f'Unknown projector type: {projector_type}') + + + +def build_gen_vision_projector(config, delay_load=False, **kwargs): + projector_type = getattr(config, 'gen_projector_type', 'linear') + + if projector_type == 'linear': + return nn.Linear(config.gen_hidden_size, config.hidden_size) + + mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) + if mlp_gelu_match: + mlp_depth = int(mlp_gelu_match.group(1)) + modules = [nn.Linear(config.gen_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + return nn.Sequential(*modules) + + if projector_type == 'identity': + return IdentityMap() + + raise ValueError(f'Unknown projector type: {projector_type}') + + +def build_down_projector(config, delay_load=False, **kwargs): + + return IdentityMap() diff --git a/blip3o/model/nextdit_crossattn.py b/blip3o/model/nextdit_crossattn.py new file mode 100644 index 0000000000000000000000000000000000000000..277011cea2385bf6c6ea4f551a407f8c02f036e6 --- /dev/null +++ b/blip3o/model/nextdit_crossattn.py @@ -0,0 +1,95 @@ +from typing import Optional + +import torch +from diffusers.models.embeddings import get_2d_rotary_pos_embed_lumina +from transformers import PretrainedConfig, PreTrainedModel + +from blip3o.model.lumina_nextdit2d import LuminaNextDiT2DModel + + +class NextDiTCrossAttnConfig(PretrainedConfig): + model_type = "nextdit-crossattn" + + def __init__( + self, + input_size: int = 8, + patch_size: int = 1, + in_channels: int = 1792, + dim: int = 1792, + n_layers: int = 24, + n_heads: int = 28, + n_kv_heads: int = 28, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: float = 1e-5, + latent_embedding_size: int = 3584, + learn_sigma: bool = False, + qk_norm: bool = True, + _gradient_checkpointing: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.input_size = input_size + self.patch_size = patch_size + self.in_channels = in_channels + self.dim = dim + self.n_layers = n_layers + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.multiple_of = multiple_of + self.ffn_dim_multiplier = ffn_dim_multiplier + self.norm_eps = norm_eps + self.learn_sigma = learn_sigma + self.qk_norm = qk_norm + self.latent_embedding_size = latent_embedding_size + self._gradient_checkpointing = _gradient_checkpointing + + +class NextDiTCrossAttn(PreTrainedModel): + config_class = NextDiTCrossAttnConfig + + def __init__( + self, + config: NextDiTCrossAttnConfig, + ) -> None: + super().__init__(config) + assert config.learn_sigma is False, "learn_sigma is not supported in nextdit-crossattn" + self._gradient_checkpointing = config._gradient_checkpointing + + self.model = LuminaNextDiT2DModel( + sample_size=config.input_size, + patch_size=config.patch_size, + in_channels=config.in_channels, + hidden_size=config.dim, + num_layers=config.n_layers, + num_attention_heads=config.n_heads, + num_kv_heads=config.n_kv_heads, + multiple_of=config.multiple_of, + 
ffn_dim_multiplier=config.ffn_dim_multiplier, + norm_eps=config.norm_eps, + learn_sigma=config.learn_sigma, + qk_norm=config.qk_norm, + cross_attention_dim=config.latent_embedding_size, + ) + + if self._gradient_checkpointing: + self.model.enable_gradient_checkpointing() + + # self.model.requires_grad_(False) + + self.freqs_cis = get_2d_rotary_pos_embed_lumina( + config.dim // config.n_heads, + 384, + 384, + ) + + def forward(self, x, timestep, z_latents, **kwargs): + model_pred = self.model( + hidden_states=x, + timestep=timestep, + encoder_hidden_states=z_latents, + encoder_mask=torch.ones((z_latents.shape[0], z_latents.shape[1]), device=z_latents.device), + image_rotary_emb=self.freqs_cis, + cross_attention_kwargs=dict(), + ).sample + return model_pred diff --git a/blip3o/model/utils.py b/blip3o/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7ef2a50fdb87249acfff6491b59cb617700d5412 --- /dev/null +++ b/blip3o/model/utils.py @@ -0,0 +1,20 @@ +from transformers import AutoConfig + + +def auto_upgrade(config): + cfg = AutoConfig.from_pretrained(config) + if 'blip3o' in config and 'blip3o' not in cfg.model_type: + assert cfg.model_type == 'llama' + print("You are using newer blip3o code base, while the checkpoint of v0 is from older code base.") + print("You must upgrade the checkpoint to the new code base (this can be done automatically).") + confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") + if confirm.lower() in ["y", "yes"]: + print("Upgrading checkpoint...") + assert len(cfg.architectures) == 1 + setattr(cfg.__class__, "model_type", "blip3o") + cfg.architectures[0] = 'blip3oLlamaForCausalLM' + cfg.save_pretrained(config) + print("Checkpoint upgraded.") + else: + print("Checkpoint upgrade aborted.") + exit(1) diff --git a/blip3o/train/blip3o_trainer.py b/blip3o/train/blip3o_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..f817684cbbba3f75f8f94694fdc5e9b8d892a65e --- /dev/null +++ b/blip3o/train/blip3o_trainer.py @@ -0,0 +1,288 @@ +import os +import torch +import torch.nn as nn +import numpy as np +from torch.utils.data import Sampler + +from transformers import Trainer +from transformers.trainer import ( + is_sagemaker_mp_enabled, + get_parameter_names, + has_length, + ALL_LAYERNORM_LAYERS, + logger, +) +from typing import List, Optional +from transformers.utils import is_torch_xla_available + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + import torch_xla.debug.metrics as met + from torch_xla import __version__ as XLA_VERSION + + IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION) + if IS_XLA_FSDPV2_POST_2_2: + import torch_xla.distributed.spmd as xs + import torch_xla.runtime as xr +else: + IS_XLA_FSDPV2_POST_2_2 = False + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + print(name, "no ignore status") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in 
to_return.items()} + return to_return + + +def split_to_even_chunks(indices, lengths, num_chunks): + """ + Split a list of indices into `chunks` chunks of roughly equal lengths. + """ + + if len(indices) % num_chunks != 0: + return [indices[i::num_chunks] for i in range(num_chunks)] + + num_indices_per_chunk = len(indices) // num_chunks + + chunks = [[] for _ in range(num_chunks)] + chunks_lengths = [0 for _ in range(num_chunks)] + for index in indices: + shortest_chunk = chunks_lengths.index(min(chunks_lengths)) + chunks[shortest_chunk].append(index) + chunks_lengths[shortest_chunk] += lengths[index] + if len(chunks[shortest_chunk]) == num_indices_per_chunk: + chunks_lengths[shortest_chunk] = float("inf") + + return chunks + + +def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + assert all(l != 0 for l in lengths), "Should not have zero length." + if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): + # all samples are in the same modality + return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator) + mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) + lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) + + mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)] + lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)] + megabatch_size = world_size * batch_size + mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] + lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] + + last_mm = mm_megabatches[-1] + last_lang = lang_megabatches[-1] + additional_batch = last_mm + last_lang + megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] + megabatch_indices = torch.randperm(len(megabatches), generator=generator) + megabatches = [megabatches[i] for i in megabatch_indices] + + if len(additional_batch) > 0: + megabatches.append(sorted(additional_batch)) + + return [i for megabatch in megabatches for i in megabatch] + + +def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + indices = torch.randperm(len(lengths), generator=generator) + megabatch_size = world_size * batch_size + megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] + megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] + megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] + + return [i for megabatch in megabatches for batch in megabatch for i in batch] + + +class LengthGroupedSampler(Sampler): + r""" + Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while + keeping a bit of randomness. 
+ """ + + def __init__( + self, + batch_size: int, + world_size: int, + lengths: Optional[List[int]] = None, + generator=None, + group_by_modality: bool = False, + ): + if lengths is None: + raise ValueError("Lengths must be provided.") + + self.batch_size = batch_size + self.world_size = world_size + self.lengths = lengths + self.generator = generator + self.group_by_modality = group_by_modality + + def __len__(self): + return len(self.lengths) + + def __iter__(self): + if self.group_by_modality: + indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) + else: + indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) + return iter(indices) + + +class blip3oTrainer(Trainer): + + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.train_dataset is None or not has_length(self.train_dataset): + return None + + if self.args.group_by_modality_length: + lengths = self.train_dataset.modality_lengths + return LengthGroupedSampler( + self.args.train_batch_size, + world_size=self.args.world_size * self.args.gradient_accumulation_steps, + lengths=lengths, + group_by_modality=True, + ) + else: + return super()._get_train_sampler() + + # def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time): + # if not hasattr(self, "largest_loss"): + # self.largest_loss = tr_loss.item() + # self.largest_grad_norm = grad_norm + # self.latest_grad_norm = grad_norm + # else: + # if tr_loss.item() > 10 * self.largest_loss: + # print(f"Loss Spiked: {tr_loss.item()} -> {self.largest_loss}") + # self.control.should_training_stop = True + # if grad_norm > 10 * self.latest_grad_norm and grad_norm > 3: + # print(f"Grad Norm Spiked: {grad_norm} -> {self.latest_grad_norm}") + # self.control.should_training_stop = True + # self.largest_loss = max(tr_loss.item(), self.largest_loss) + # self.largest_grad_norm = max(grad_norm, self.largest_grad_norm) + # self.latest_grad_norm = grad_norm + + # if np.isnan(grad_norm) or grad_norm > 1e6: + # print(f"NaN grad norm detected in process {self.args.process_index} on {os.uname().nodename}") + # self.control.should_training_stop = True + # print(f"Shut Down Training") + + # if self.control.should_log and self.state.global_step > self._globalstep_last_logged: + # if is_torch_xla_available(): + # xm.mark_step() + + # logs: Dict[str, float] = {} + + # # all_gather + mean() to get average loss over all processes + # tr_loss_scalar = self._nested_gather(tr_loss).mean().item() + + # # reset tr_loss to zero + # tr_loss -= tr_loss + + # logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + # if grad_norm is not None: + # logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm + # logs["learning_rate"] = self._get_learning_rate() + + # self._total_loss_scalar += tr_loss_scalar + # self._globalstep_last_logged = self.state.global_step + # self.store_flos() + + # self.log(logs, start_time) + + # metrics = None + # if self.control.should_evaluate: + # metrics = self._evaluate(trial, ignore_keys_for_eval) + # is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial) + + # if self.args.save_strategy == SaveStrategy.BEST: + # self.control.should_save = is_new_best_metric + + # if self.control.should_save: + # self._save_checkpoint(model, trial) + # self.control = 
self.callback_handler.on_save(self.args, self.state, self.control) + + + def create_optimizer(self): + """ + Setup the optimizer. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. + """ + if is_sagemaker_mp_enabled(): + return super().create_optimizer() + + opt_model = self.model + + if self.optimizer is None: + decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + if self.args.mm_projector_lr is not None: + projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)], + "weight_decay": self.args.weight_decay, + }, + { + "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)], + "weight_decay": 0.0, + }, + { + "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)], + "weight_decay": self.args.weight_decay, + "lr": self.args.mm_projector_lr, + }, + { + "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)], + "weight_decay": 0.0, + "lr": self.args.mm_projector_lr, + }, + ] + else: + optimizer_grouped_parameters = [ + { + "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)], + "weight_decay": self.args.weight_decay, + }, + { + "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)], + "weight_decay": 0.0, + }, + ] + + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + + self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + if optimizer_cls.__name__ == "Adam8bit": + import bitsandbytes + + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() + + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) + logger.info(f"skipped {module}: {skipped/2**20}M params") + manager.register_module_override(module, "weight", {"optim_bits": 32}) + logger.debug(f"bitsandbytes: will optimize {module} in fp32") + logger.info(f"skipped: {skipped/2**20}M params") + + return self.optimizer + diff --git a/blip3o/train/llama_flash_attn_monkey_patch.py b/blip3o/train/llama_flash_attn_monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..31db2eff8d1c4b3ae645583dfc5e156e818b6f1c --- /dev/null +++ b/blip3o/train/llama_flash_attn_monkey_patch.py @@ -0,0 +1,115 @@ +from typing import Optional, Tuple +import warnings + +import torch + +import transformers +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv + +try: + from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func +except ImportError: + from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func +from flash_attn.bert_padding import unpad_input, pad_input + + +def forward( + self, + hidden_states: torch.Tensor, + attention_mask: 
Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + warnings.warn( + "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_key_value_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_key_value_heads, self.head_dim) + .transpose(1, 2) + ) # shape: (b, num_heads, s, head_dim) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + if past_key_value is not None: + # reuse k, v + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Transform the data into the format required by flash attention + qkv = torch.stack([query_states, key_states, value_states], dim=2) + qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] + key_padding_mask = attention_mask + + if key_padding_mask is None: + qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) + cu_q_lens = torch.arange( + 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device + ) + max_s = q_len + output = flash_attn_unpadded_qkvpacked_func( + qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True + ) + output = output.view(bsz, q_len, -1) + else: + qkv = qkv.reshape(bsz, q_len, -1) + qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + output_unpad = flash_attn_unpadded_qkvpacked_func( + qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True + ) + output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) + output = pad_input(output_unpad, indices, bsz, q_len) + + return self.o_proj(output), None, past_key_value + + +# Disable the transformation of the attention mask in LlamaModel as the flash attention +# requires the attention mask to be the same as the key_padding_mask +def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length +): + # [bsz, seq_len] + return attention_mask + + +def replace_llama_attn_with_flash_attn(): + cuda_major, cuda_minor = torch.cuda.get_device_capability() + if cuda_major < 8: + warnings.warn( + "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" + ) + transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( + _prepare_decoder_attention_mask + ) + transformers.models.llama.modeling_llama.LlamaAttention.forward = forward diff --git a/blip3o/train/llama_xformers_attn_monkey_patch.py b/blip3o/train/llama_xformers_attn_monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..f8351e41ccd4a64dca237bd8f8be0702b23989dc --- /dev/null +++ b/blip3o/train/llama_xformers_attn_monkey_patch.py @@ -0,0 +1,129 @@ +""" +Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments +""" + +import logging +import math +from typing import Optional, Tuple + +import torch +import transformers.models.llama.modeling_llama +from torch import nn + +try: + import xformers.ops +except ImportError: + logging.error("xformers not found! Please install it before trying to use it.") + + +def replace_llama_attn_with_xformers_attn(): + transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward + + +def xformers_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # pylint: disable=duplicate-code + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + ( + query_states, + key_states, + ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # We only apply xformers optimizations if we don't need to output the whole attention matrix + if not output_attentions: + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. + # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 
+ if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: + # input and output should be of form (bsz, q_len, num_heads, head_dim) + attn_output = xformers.ops.memory_efficient_attention( + query_states, key_states, value_states, attn_bias=None + ) + else: + # input and output should be of form (bsz, q_len, num_heads, head_dim) + attn_output = xformers.ops.memory_efficient_attention( + query_states, + key_states, + value_states, + attn_bias=xformers.ops.LowerTriangularMask(), + ) + attn_weights = None + else: + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value diff --git a/blip3o/train/train_mem.py b/blip3o/train/train_mem.py new file mode 100644 index 0000000000000000000000000000000000000000..19b1ea1df190e6b51801c745a3eaeb26baf22071 --- /dev/null +++ b/blip3o/train/train_mem.py @@ -0,0 +1,4 @@ +from blip3o.train.train import train + +if __name__ == "__main__": + train(attn_implementation="flash_attention_2") diff --git a/blip3o/train/train_xformers.py b/blip3o/train/train_xformers.py new file mode 100644 index 0000000000000000000000000000000000000000..08b939dbd8df66bfc26a8f13eb62af6fe1c146a9 --- /dev/null +++ b/blip3o/train/train_xformers.py @@ -0,0 +1,10 @@ +from blip3o.train.llama_xformers_attn_monkey_patch import ( + replace_llama_attn_with_xformers_attn, +) + +replace_llama_attn_with_xformers_attn() + +from blip3o.train.train import train + +if __name__ == "__main__": + train() diff --git a/blip3o/utils.py b/blip3o/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8345bb6c9f7fd046017060a540e7d12e33141381 --- /dev/null +++ b/blip3o/utils.py @@ -0,0 +1,126 @@ +import datetime +import logging +import logging.handlers +import os +import sys + +import requests + +from blip3o.constants import LOGDIR + +server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" +moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
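+# Module-level cache for the shared rotating file handler created in build_logger().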
+ +handler = None + + +def build_logger(logger_name, logger_filename): + global handler + + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Set the format of root handlers + if not logging.getLogger().handlers: + logging.basicConfig(level=logging.INFO) + logging.getLogger().handlers[0].setFormatter(formatter) + + # Redirect stdout and stderr to loggers + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + sl = StreamToLogger(stdout_logger, logging.INFO) + sys.stdout = sl + + stderr_logger = logging.getLogger("stderr") + stderr_logger.setLevel(logging.ERROR) + sl = StreamToLogger(stderr_logger, logging.ERROR) + sys.stderr = sl + + # Get logger + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + # Add a file handler for all loggers + if handler is None: + os.makedirs(LOGDIR, exist_ok=True) + filename = os.path.join(LOGDIR, logger_filename) + handler = logging.handlers.TimedRotatingFileHandler( + filename, when='D', utc=True, encoding='UTF-8') + handler.setFormatter(formatter) + + for name, item in logging.root.manager.loggerDict.items(): + if isinstance(item, logging.Logger): + item.addHandler(handler) + + return logger + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + def __init__(self, logger, log_level=logging.INFO): + self.terminal = sys.stdout + self.logger = logger + self.log_level = log_level + self.linebuf = '' + + def __getattr__(self, attr): + return getattr(self.terminal, attr) + + def write(self, buf): + temp_linebuf = self.linebuf + buf + self.linebuf = '' + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross platform. + if line[-1] == '\n': + self.logger.log(self.log_level, line.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != '': + self.logger.log(self.log_level, self.linebuf.rstrip()) + self.linebuf = '' + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def violates_moderation(text): + """ + Check whether the text violates OpenAI moderation API. + """ + url = "https://api.openai.com/v1/moderations" + headers = {"Content-Type": "application/json", + "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} + text = text.replace("\n", "") + data = "{" + '"input": ' + f'"{text}"' + "}" + data = data.encode("utf-8") + try: + ret = requests.post(url, headers=headers, data=data, timeout=5) + flagged = ret.json()["results"][0]["flagged"] + except requests.exceptions.RequestException as e: + flagged = False + except KeyError as e: + flagged = False + + return flagged + + +def pretty_print_semaphore(semaphore): + if semaphore is None: + return "None" + return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..024b590bec57e55e7450f91552d976c40c917dd0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,35 @@ +-e . 
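+# the "-e ." entry above installs the local blip3o package (see setup.py) in editable mode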
+ +torch==2.3.0 +torchvision==0.18.0 +torchaudio==2.3.0 +xformers==0.0.26.post1 +tokenizers +sentencepiece +shortuuid +transformers==4.51.3 +accelerate +peft +bitsandbytes +pydantic +markdown2[all] +numpy +scikit-learn +gradio==4.16.0 +gradio_client==0.8.1 +requests +httpx==0.24.0 +uvicorn +fastapi +einops==0.6.1 +einops-exts==0.0.4 +timm==0.6.13 +ftfy +git+https://github.com/JiuhaiChen/diffusers.git +deepspeed +datasets +tabulate +flash_attn==2.6.2 +ninja +qwen_vl_utils +huggingface_hub diff --git a/run.sh b/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..698bc1403f49d6e5909b83208bc22771cd902643 --- /dev/null +++ b/run.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +conda activate blip3o + + +export HF_HOME=/HF/Home/ +export OUTPUT_FOLDER=/Your/Model/Output/ +export IMG_FOLDER=/Your/Image/Folder + + +torchrun --nproc_per_node=8 \ + blip3o/train/train_mem.py \ + --deepspeed ./deepspeed_scripts/zero1.json \ + --model_name_or_path Qwen/Qwen2.5-VL-7B-Instruct \ + --version qwen \ + --data_type "mix" \ + --image_folder ${IMG_FOLDER} \ + --gen_vision_tower eva-clip-E-14-plus \ + --gen_projector_type mlp2x_gelu \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --bf16 True \ + --output_dir ${OUTPUT_FOLDER} \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --eval_strategy "no" \ + --save_strategy "steps" \ + --save_steps 1000 \ + --save_total_limit 1 \ + --learning_rate 1e-4 \ + --weight_decay 0. \ + --warmup_ratio 0.003 \ + --lr_scheduler_type "cosine_with_min_lr" \ + --lr_scheduler_kwargs '{"min_lr":1e-5}' \ + --model_max_length 512 \ + --logging_steps 1 \ + --tf32 True \ + --gradient_checkpointing True \ + --dataloader_num_workers 4 \ + --lazy_preprocess True \ + --gen_pooling early_pool2d_4 \ + --n_query 64 \ + --n_und_query 0 \ + --report_to none \ + --run_name blip3o_qwen_vl_7b + + + + diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..8ac21da1af313d64d46eb3c26bebc7327601e10b --- /dev/null +++ b/setup.py @@ -0,0 +1,7 @@ +from setuptools import setup, find_packages + +setup( + name="blip3o", + version="0.1.0", + packages=find_packages(), +)
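Two hedged usage sketches follow. They are illustrations against the code added in this patch, not part of the diff itself: the `SimpleNamespace` config, the 1152/3584 hidden sizes, and the toy length list are assumptions made for demonstration, and both snippets assume the package has been installed with `pip install -e .` and that its imports resolve in your environment.

```python
# Sketch: how build_vision_projector expands an "mlpNx_gelu" projector string.
from types import SimpleNamespace

import torch

from blip3o.model.multimodal_projector.builder import build_vision_projector

# Illustrative sizes only; real values come from the model config.
cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1152, hidden_size=3584)

projector = build_vision_projector(cfg)
# "mlp2x_gelu" -> Linear(1152, 3584) -> GELU() -> Linear(3584, 3584)

vision_features = torch.randn(1, 729, cfg.mm_hidden_size)  # 729 tokens, matching the vision tower's assert
print(projector(vision_features).shape)  # torch.Size([1, 729, 3584])
```

The length-grouped sampling in `blip3o_trainer.py` can be exercised on toy lengths in the same spirit:

```python
# Sketch: LengthGroupedSampler builds world_size * batch_size "megabatches" of
# similar-length samples, then balances each megabatch across ranks.
import torch

from blip3o.train.blip3o_trainer import LengthGroupedSampler

lengths = [8, 120, 33, 7, 512, 64, 90, 18]  # toy per-sample token counts
sampler = LengthGroupedSampler(
    batch_size=2,
    world_size=2,
    lengths=lengths,
    generator=torch.Generator().manual_seed(0),
)
print(list(sampler))  # a permutation of range(8) with similar lengths kept adjacent
```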