diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..00edda492e150cd13281b0b37a2b0eafd4e483e7 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,23 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +llava/model/coat/optimizer/kernels/build/lib.linux-x86_64-cpython-310/qoptim_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/bindings.o filter=lfs diff=lfs merge=lfs -text +llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda.o filter=lfs diff=lfs merge=lfs -text +llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda.o filter=lfs diff=lfs merge=lfs -text +static/af3_main_diagram-1.png filter=lfs diff=lfs merge=lfs -text +static/af3_radial-1.png filter=lfs diff=lfs merge=lfs -text +static/af3_sota.png filter=lfs diff=lfs merge=lfs -text +static/audio/audio2.wav filter=lfs diff=lfs merge=lfs -text +static/chat/audio1.mp3 filter=lfs diff=lfs merge=lfs -text +static/chat/audio2.mp3 filter=lfs diff=lfs merge=lfs -text +static/emergent/audio1.wav filter=lfs diff=lfs merge=lfs -text +static/logo-no-bg.png filter=lfs diff=lfs merge=lfs -text +static/speech/339a1acd-afcb-466b-a7b1-8661e59b1e56.wav filter=lfs diff=lfs merge=lfs -text +static/speech/audio3.wav filter=lfs diff=lfs merge=lfs -text +static/speech/bcc6057d-0dda-435d-b956-a96ab27bc9e4.wav filter=lfs diff=lfs merge=lfs -text +static/speech/be84d293-5e9c-4158-9a1e-b4dd1acb7d70.wav filter=lfs diff=lfs merge=lfs -text +static/speech/fec3402e-7883-45c0-90d4-38647f615dc3.wav filter=lfs diff=lfs merge=lfs -text +static/think/audio1.wav filter=lfs diff=lfs merge=lfs -text +static/think/audio2.wav filter=lfs diff=lfs merge=lfs -text +static/voice/voice_2.mp3 filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index d27df3b03f3c9bcb6083a4116fd414c37d3117fa..df539d80c9ed0e13695926463a336d777f263e06 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,139 @@ ---- -title: Audio Flamingo 3 -emoji: ⚡ -colorFrom: pink -colorTo: purple -sdk: gradio -sdk_version: 5.36.2 -app_file: app.py -pinned: false -license: other -short_description: Online demo for Audio Flamingo 3 ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference + +
+
+# Audio Flamingo 3 🔥🚀🔥
+
+**Audio Flamingo 3: Advancing Audio Intelligence with Fully Open Large Audio-Language Models**
+
+
+## Overview
+
+This repo contains the PyTorch implementation of [Audio Flamingo 3: Advancing Audio Intelligence with Fully Open Large Audio-Language Models](). Audio Flamingo 3 (AF3) is a fully open, state-of-the-art Large Audio-Language Model (LALM) that advances reasoning and understanding across speech, sounds, and music. AF3 builds on previous work with innovations in:
+
+- Unified audio representation learning (speech, sound, music)
+- Flexible, on-demand chain-of-thought reasoning (Thinking in Audio)
+- Long-context audio comprehension (up to 10 minutes, including speech)
+- Multi-turn, multi-audio conversational dialogue (AF3-Chat)
+- Voice-to-voice interaction (AF3-Chat)
+
+Extensive evaluations confirm AF3’s effectiveness, setting new benchmarks on over 20 public audio understanding and reasoning tasks.
+
+
+## Main Results
+
+Audio Flamingo 3 outperforms prior SOTA models, including GAMA, Audio Flamingo, Audio Flamingo 2, Qwen-Audio, Qwen2-Audio, Qwen2.5-Omni, LTU, LTU-AS, SALMONN, AudioGPT, Gemini Flash v2, and Gemini Pro v1.5, on a number of understanding and reasoning benchmarks.
+
+
+<!-- Benchmark result figures: static/af3_radial-1.png, static/af3_sota.png -->
+
+
+## Audio Flamingo 3 Architecture
+
+Audio Flamingo 3 uses the AF-Whisper unified audio encoder, an MLP-based audio adaptor, a decoder-only LLM backbone (Qwen2.5-7B), and a streaming TTS module (AF3-Chat). Audio Flamingo 3 can take audio inputs up to 10 minutes long.
+
+
+<!-- Architecture diagram: static/af3_main_diagram-1.png -->
+
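+The data flow through these components can be summarized in a short sketch. This is an illustrative outline only, not the exact API of this repo: the module names, the embedding dimensions (a 1280-d Whisper-style encoder output and Qwen2.5-7B's 3584-d hidden size), and the plain concatenation of audio and text embeddings are simplifying assumptions.
+
+```python
+import torch
+import torch.nn as nn
+
+
+class AudioAdaptor(nn.Module):
+    """Illustrative MLP adaptor: projects AF-Whisper window embeddings into the LLM hidden space."""
+
+    def __init__(self, audio_dim: int = 1280, llm_dim: int = 3584):
+        super().__init__()
+        self.proj = nn.Sequential(nn.Linear(audio_dim, llm_dim), nn.GELU(), nn.Linear(llm_dim, llm_dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.proj(x)
+
+
+def af3_forward_sketch(audio_windows, text_embeds, audio_encoder, adaptor, llm):
+    # Long audio is split into 30-second windows; each window is turned into a 128-bin mel
+    # spectrogram and encoded by AF-Whisper into a sequence of (roughly 750) embeddings.
+    audio_embeds = [adaptor(audio_encoder(w)) for w in audio_windows]
+    # The projected audio embeddings are spliced into the text embedding sequence at the
+    # <sound> placeholder positions (simplified here to a plain concatenation).
+    inputs_embeds = torch.cat(audio_embeds + [text_embeds], dim=1)
+    # The decoder-only Qwen2.5-7B backbone generates the response autoregressively; for
+    # AF3-Chat, a streaming TTS module converts the generated text into speech.
+    return llm(inputs_embeds=inputs_embeds)
+```
+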
+
+## Installation
+
+```bash
+./environment_setup.sh af3
+```
+
+## Code Structure
+
+- The folder ```audio_flamingo_3/``` contains the main training and inference code of Audio Flamingo 3.
+- The folder ```audio_flamingo_3/scripts``` contains the inference scripts of Audio Flamingo 3 in case you would like to use our pretrained checkpoints on HuggingFace.
+
+Each folder is self-contained and we expect no cross dependencies between these folders. This repo does not contain the code for the Streaming-TTS pipeline, which will be released in the near future.
+
+## Single Line Inference
+
+To run inference with the stage 3 model directly, run the command below:
+```bash
+python llava/cli/infer_audio.py --model-base /path/to/checkpoint/af3-7b --conv-mode auto --text "Please describe the audio in detail" --media static/audio1.wav
+```
+
+To run inference with the stage 3.5 model, run the command below:
+```bash
+python llava/cli/infer_audio.py --model-base /path/to/checkpoint/af3-7b --model-path /path/to/checkpoint/af3-7b/stage35 --conv-mode auto --text "Please describe the audio in detail" --media static/audio1.wav --peft-mode
+```
+
+## References
+
+The main training and inference code within each folder is modified from [NVILA](https://github.com/NVlabs/VILA/tree/main) ([Apache license](incl_licenses/License_1.md)).
+
+## License
+
+- The code in this repo is under the [MIT license](incl_licenses/MIT_license.md).
+- The checkpoints are for non-commercial use only, under the [NVIDIA OneWay Noncommercial License](incl_licenses/NVIDIA_OneWay_Noncommercial_License.docx). They are also subject to the [Qwen Research license](https://huggingface.co/Qwen/Qwen2.5-7B/blob/main/LICENSE), the [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and the original licenses accompanying each training dataset.
+- Notice: Audio Flamingo 3 is built with Qwen-2.5. Qwen is licensed under the Qwen RESEARCH LICENSE AGREEMENT, Copyright (c) Alibaba Cloud. All Rights Reserved.
+
+
+## Citation
+
+- Audio Flamingo 2
+```
+@article{ghosh2025audio,
+  title={Audio Flamingo 2: An Audio-Language Model with Long-Audio Understanding and Expert Reasoning Abilities},
+  author={Ghosh, Sreyan and Kong, Zhifeng and Kumar, Sonal and Sakshi, S and Kim, Jaehyeon and Ping, Wei and Valle, Rafael and Manocha, Dinesh and Catanzaro, Bryan},
+  journal={arXiv preprint arXiv:2503.03983},
+  year={2025}
+}
+```
+
+- Audio Flamingo
+```
+@inproceedings{kong2024audio,
+  title={Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities},
+  author={Kong, Zhifeng and Goel, Arushi and Badlani, Rohan and Ping, Wei and Valle, Rafael and Catanzaro, Bryan},
+  booktitle={International Conference on Machine Learning},
+  pages={25125--25148},
+  year={2024},
+  organization={PMLR}
+}
+```
+
+
diff --git a/llava/__init__.py b/llava/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c14ac15313d1fceee30d9a3ef2d69bbba5702722
--- /dev/null
+++ b/llava/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .entry import *
+from .media import *
diff --git a/llava/cli/infer_audio.py b/llava/cli/infer_audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..8041b2ab9aa52fc166c87ce4527ad5a8356857da
--- /dev/null
+++ b/llava/cli/infer_audio.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +import argparse +import importlib.util +import json +import os + +from pydantic import BaseModel +from termcolor import colored + +import llava +from llava import conversation as clib +from llava.media import Image, Video, Sound +from llava.model.configuration_llava import JsonSchemaResponseFormat, ResponseFormat +from peft import PeftModel +import torch + +def get_schema_from_python_path(path: str) -> str: + schema_path = os.path.abspath(path) + spec = importlib.util.spec_from_file_location("schema_module", schema_path) + schema_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(schema_module) + + # Get the Main class from the loaded module + Main = schema_module.Main + assert issubclass( + Main, BaseModel + ), f"The provided python file {path} does not contain a class Main that describes a JSON schema" + return Main.schema_json() + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--model-base", "-mb", type=str, required=True) + parser.add_argument("--model-path", "-mp", type=str, required=True) + parser.add_argument("--conv-mode", "-c", type=str, default="auto") + parser.add_argument("--text", type=str) + parser.add_argument("--media", type=str, nargs="+") + parser.add_argument("--json-mode", action="store_true") + parser.add_argument("--peft-mode", action="store_true") + parser.add_argument("--json-schema", type=str, default=None) + args = parser.parse_args() + + # Convert json mode to response format + if not args.json_mode: + response_format = None + elif args.json_schema is None: + response_format = ResponseFormat(type="json_object") + else: + schema_str = get_schema_from_python_path(args.json_schema) + print(schema_str) + response_format = ResponseFormat(type="json_schema", json_schema=JsonSchemaResponseFormat(schema=schema_str)) + + # Load model + model = llava.load(args.model_base) + if args.peft_mode: + model = PeftModel.from_pretrained( + model, + args.model_path, + device_map="auto", + torch_dtype=torch.float16, + ) + # Set conversation mode + clib.default_conversation = clib.conv_templates[args.conv_mode].copy() + + # Prepare multi-modal prompt + prompt = [] + if args.media is not None: + for media in args.media or []: + if any(media.endswith(ext) for ext in [".wav",".mp3", ".flac"]): + media = Sound(media) + else: + raise ValueError(f"Unsupported media type: {media}") + prompt.append(media) + if args.text is not None: + prompt.append(args.text) + + # Generate response + response = model.generate_content(prompt, response_format=response_format) + print(colored(response, "cyan", attrs=["bold"])) + + +if __name__ == "__main__": + main() diff --git a/llava/constants.py b/llava/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..e4e557b88bf7437b49bb88c754d71218ed79dfae --- /dev/null +++ b/llava/constants.py @@ -0,0 +1,55 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +# This file is modified from https://github.com/haotian-liu/LLaVA/ + +CONTROLLER_HEART_BEAT_EXPIRATION = 30 +WORKER_HEART_BEAT_INTERVAL = 15 + +LOGDIR = "." + +# Model Constants +IGNORE_INDEX = -100 +DEFAULT_SOUND_TOKEN = "" +DEFAULT_SPEECH_TOKEN = "" +SENTINEL_TOKEN = "" + +MEDIA_TOKENS = { + "speech": "", + "sound": "", +} + + +""" +151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), +151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), +151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), +151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), +151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), +151648: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), +151649: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), +151650: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), +151651: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), +151652: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), + +""" +NUM_EXTRA_TOKENS = 10 diff --git a/llava/conversation.py b/llava/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..5b42e5bf88571231b262040d3601e63006e125e9 --- /dev/null +++ b/llava/conversation.py @@ -0,0 +1,197 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 +# This file is modified from https://github.com/haotian-liu/LLaVA/ + +import dataclasses +from enum import Enum, auto +from typing import List + +from llava.utils.logging import logger + + +class SeparatorStyle(Enum): + """Different separator style.""" + + AUTO = auto() + TWO = auto() + MPT = auto() + PLAIN = auto() + LLAMA_3 = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + + system: str + roles: List[str] + messages: List[List[str]] + sep_style: SeparatorStyle = SeparatorStyle.AUTO + sep: str = "###" + sep2: str = None + version: str = "Unknown" + + def get_prompt(self): + messages = self.messages + if len(messages) > 0 and type(messages[0][1]) is tuple: + messages = self.messages.copy() + init_role, init_msg = messages[0].copy() + init_msg = init_msg[0].replace("", "").strip() + messages[0] = (init_role, "\n" + init_msg) + + if self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.LLAMA_3: + ret = self.system + self.sep + for rid, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message = message[0] + sep = self.sep if rid < len(messages) - 1 else self.sep2 + ret += role + message + sep + else: + ret += role + elif self.sep_style == SeparatorStyle.MPT: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + elif self.sep_style == SeparatorStyle.PLAIN: + seps = [self.sep, self.sep2] + ret = self.system + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += message + seps[i % 2] + else: + ret += "" + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + return ret + + def append_message(self, role, message): + self.messages.append([role, message]) + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + version=self.version, + ) + + +conv_auto = Conversation( + system="", + roles=("", ""), + messages=(), + sep_style=SeparatorStyle.AUTO, + sep="\n", +) + +conv_vicuna_v1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_llava_plain = Conversation( + system="", + roles=("", ""), + messages=(), + sep_style=SeparatorStyle.PLAIN, + sep="\n", +) + +hermes_2 = Conversation( + system="<|im_start|>system\nAnswer the questions.", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", + messages=(), + version="hermes-2", +) + +# Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template. +llama_3_chat = Conversation( + system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. 
" + "You are able to understand the visual content that the user provides, " + "and assist the user with a variety of tasks using natural language.", + roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"), + version="llama_v3", + messages=(), + sep_style=SeparatorStyle.LLAMA_3, + sep="<|eot_id|>", + sep2="<|end_of_text|>", +) + + +default_conversation = conv_auto +conv_templates = { + "auto": conv_auto, + "hermes-2": hermes_2, + "llama_3": llama_3_chat, + "v1": conv_vicuna_v1, + "vicuna_v1": conv_vicuna_v1, + "plain": conv_llava_plain, +} + + +CONVERSATION_MODE_MAPPING = { + "vila1.5-3b": "vicuna_v1", + "vila1.5-8b": "llama_3", + "vila1.5-13b": "vicuna_v1", + "vila1.5-40b": "hermes-2", + "llama-3": "llama_3", + "llama3": "llama_3", +} + + +def auto_set_conversation_mode(model_name_or_path: str) -> str: + global default_conversation + for k, v in CONVERSATION_MODE_MAPPING.items(): + if k in model_name_or_path.lower(): + logger.info(f"Setting conversation mode to `{v}` based on model name/path `{model_name_or_path}`.") + default_conversation = conv_templates[v] + return diff --git a/llava/data/__init__.py b/llava/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22d60dab57c08f90e224f3d95047c5dfbf14de6d --- /dev/null +++ b/llava/data/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +from .builder import * +from .dataset import * +from .datasets_mixture import * diff --git a/llava/data/base.py b/llava/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..925a500cf78667b3ab862a72ab1b564e86c2c746 --- /dev/null +++ b/llava/data/base.py @@ -0,0 +1,95 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
+ +import random +from typing import Any, Dict, List + +import torch +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizer + +from llava.mm_utils import dynamic_process_images_and_prompt, dynamic_s2_process_images_and_prompt, process_images +from llava.train.args import DataArguments +from llava.utils.logging import logger +from llava.utils.media import extract_media +from llava.utils.tokenizer import preprocess_conversation + +__all__ = ["BaseDataset"] + +def _process_speech(speech: List[Any], data_args: DataArguments) -> torch.Tensor: + return torch.tensor(speech) + +def _process_sound(sound: List[Any], data_args: DataArguments) -> torch.Tensor: + return torch.tensor(sound) + +def _process_sound_masks(sound_masks: List[Any], data_args: DataArguments) -> torch.Tensor: + return torch.tensor(sound_masks) + + +class BaseDataset(Dataset): + def __init__( + self, + tokenizer: PreTrainedTokenizer, + data_args: DataArguments, + no_system_prompt: bool = False, + **kwargs: Any, + ) -> None: + super().__init__() + self.tokenizer = tokenizer + self.data_args = data_args + self.no_system_prompt = no_system_prompt + self.instances = [] + self.enable_dynamic_res = False + self.enable_dynamic_res_s2 = False + # global_batch_size: int, + self.global_batch_size = kwargs.get("global_batch_size", 1) + + # by default, dataset cls will resample on failure + self.resample_on_failure = kwargs.get("resample_on_failure", True) + + # by default, dataset cls will resample on failure + self.resample_on_failure = kwargs.get("resample_on_failure", True) + + def process(self, instance: Dict[str, Any]) -> List[Dict[str, Any]]: + raise NotImplementedError + + def __getitem__(self, index: int) -> Dict[str, Any]: + instance = self.instances[index] + + try: + # Process instance to conversation + conversation = self.process(instance) + + # Extract media from conversation + media, media_meta = extract_media(conversation, self.data_args) + + if "speech" in media: + processed_speech = _process_speech(media["speech"], self.data_args) + if "sound" in media: + processed_sound = _process_sound(media["sound"], self.data_args) + processed_sound_feature_masks = _process_sound_masks(media_meta["sound_feature_masks"], self.data_args) + processed_sound_embed_masks = _process_sound_masks(media_meta["sound_embed_masks"], self.data_args) + # Prepare "input_ids" and "labels" for training + data = preprocess_conversation(conversation, self.tokenizer, no_system_prompt=self.no_system_prompt) + + if "speech" in media: + data["speech"] = processed_speech + if "sound" in media: + data["sound"] = processed_sound + data["sound_feature_masks"] = processed_sound_feature_masks + data["sound_embed_masks"] = processed_sound_embed_masks + + except Exception as e: + if not self.resample_on_failure: + raise e + else: + logger.exception(f"Error processing instance '{instance}': '{e}'. Resampling.") + return self.__getitem__(random.randint(0, len(self.instances) - 1)) + + return data + + def __len__(self) -> int: + return len(self.instances) diff --git a/llava/data/builder.py b/llava/data/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..871d380c089e5948a58a668846ecf96faaee4afa --- /dev/null +++ b/llava/data/builder.py @@ -0,0 +1,193 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
+ +import os +import os.path as osp +from itertools import chain +from typing import Any, List, Optional + +import torch +import torch.distributed as dist +from hydra.utils import instantiate +from torch.utils.data import ConcatDataset, Dataset +from transformers import PreTrainedTokenizer + +from llava.data.datasets_mixture import DATASETS_LEGACY +from llava.train.args import DataArguments, TrainingArguments +from llava.utils import io +from llava.utils.logging import logger +import time +import numpy as np +__all__ = ["DATASETS", "MIXTURES", "register_datasets", "register_mixtures", "parse_mixture", "build_dataset"] + + +def load_dataset_yaml(name): + fname = f"{name}.yaml" if not name.endswith(".yaml") else name + + # yaml under llava/data/registry/datasets + repo_path = osp.join(osp.dirname(__file__), "registry", "datasets", fname) + if osp.exists(repo_path): + return repo_path + + # # yaml under + abs_path = osp.expanduser(fname) + if osp.exists(abs_path): + return abs_path + + raise FileNotFoundError(f"Dataset '{name}' is not found in the {repo_path} or {abs_path}.") + + +def register_datasets(name: Optional[str] = None): + if name is None: + name = os.environ.get("VILA_DATASETS", "default") + logger.info(f"Registering datasets from environment: '{name}'.") + # return io.load(osp.join(osp.dirname(__file__), "registry", "datasets", f"{name}.yaml")) + dataset_meta = {} + for _name in name.split(","): + yamlpath = load_dataset_yaml(_name) + logger.info(f"Registering datasets from: '{yamlpath}'.") + meta = io.load(yamlpath) + dataset_meta.update(meta) + return dataset_meta + + +def register_mixtures(): + return io.load(os.path.join(os.path.dirname(__file__), "registry", "mixtures.yaml")) + + +DATASETS = register_datasets() +MIXTURES = register_mixtures() + + +def parse_mixture(mixture: str) -> List[str]: + names = mixture.split("+") if "+" in mixture else [mixture] + while any(name in MIXTURES for name in names): + names = list(chain(*[MIXTURES.get(name, [name]) for name in names])) + return sorted(names) + + +class SubsetDataset(Dataset): + def __init__(self, dataset: Dataset, limit: int) -> None: + super().__init__() + self.dataset = dataset + self.limit = limit + + def __len__(self) -> int: + return int(len(self.dataset) * self.limit) + + def __getitem__(self, index: int) -> Any: + return self.dataset[index % len(self.dataset)] + +class RepeatedDataset(Dataset): + def __init__(self, dataset: Dataset, times: int) -> None: + super().__init__() + self.dataset = dataset + self.times = times + + def __len__(self) -> int: + return len(self.dataset) * self.times + + def __getitem__(self, index: int) -> Any: + return self.dataset[index % len(self.dataset)] + + +def get_world_size(): + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + else: + return 1 + + +def build_dataset( + mixture: str, + data_args: DataArguments, + training_args: TrainingArguments, + tokenizer: PreTrainedTokenizer, +) -> Dataset: + logger.warning(f"Training VILA with mixture '{mixture}'.") + datasets = [] + dataset_rng = np.random.default_rng(1234) + for name in parse_mixture(mixture): + + if "*" in name: + name, times = name.split("*") + times = int(times) + else: + times = 1 + limit_dataset = False + if "#" in name: + # we limit the max length of this dataset + name, max_length_percent = name.split("#") + limit_dataset = True + if DATASETS is not None and name in DATASETS: + if name in DATASETS_LEGACY: + logger.warning(f"Dataset '{name}' exists in both new and legacy registries. 
Using the new one.") + dataset = instantiate(DATASETS[name], _partial_=True)( + tokenizer=tokenizer, + data_args=data_args, + global_batch_size=( + training_args.per_device_train_batch_size + # * torch.distributed.get_world_size() + * get_world_size() + * training_args.gradient_accumulation_steps + ), + ) + elif name in DATASETS_LEGACY: + logger.warning(f"Dataset '{name}' is from the legacy registry. Please consider migrating it.") + dataset = build_dataset_legacy( + name, + data_args=data_args, + training_args=training_args, + tokenizer=tokenizer, + ) + else: + raise ValueError(f"Dataset '{name}' is not found in the registries.") + + + if limit_dataset: + # we limit the max length of this dataset + max_length = int(float(int(max_length_percent) / 100.) * len(dataset)) + dataset = SubsetDataset(dataset, float(int(max_length_percent) / 100.)) + + if times > 1: + dataset = RepeatedDataset(dataset, times) + datasets.append(dataset) + return ConcatDataset(datasets) + + +def build_dataset_legacy( + name: str, + data_args: DataArguments, + training_args: TrainingArguments, + tokenizer: PreTrainedTokenizer, +) -> Dataset: + from llava.data.dataset import ( + LazySupervisedDataset, + LazyWDSDataset, + ) + + dataset = DATASETS_LEGACY[name] + dataset_type = dataset.dataset_type + if dataset_type == "torch": + dataset_cls = LazySupervisedDataset + elif dataset_type == "wds": + dataset_cls = LazyWDSDataset + else: + raise NotImplementedError(f"{dataset_type} is not supported.") + + data_args.meta_path = getattr(dataset, "meta_path", None) + data_args.caption_choice = getattr(dataset, "caption_choice", None) + data_args.caption_choice_2 = getattr(dataset, "caption_choice_2", None) + data_args.start_idx = getattr(dataset, "start_idx", None) + data_args.end_idx = getattr(dataset, "end_idx", None) + + return dataset_cls( + tokenizer=tokenizer, + data_path=dataset.data_path, + image_folder=getattr(dataset, "image_path"), + data_args=data_args, + training_args=training_args, + ) diff --git a/llava/data/collate.py b/llava/data/collate.py new file mode 100644 index 0000000000000000000000000000000000000000..4795b77a246acce9294f07b189f0a06725b71f2c --- /dev/null +++ b/llava/data/collate.py @@ -0,0 +1,166 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
+ +from dataclasses import dataclass +from typing import Any, Dict, Sequence + +import torch +from transformers import PreTrainedTokenizer + +from llava.constants import IGNORE_INDEX +from llava.utils.logging import logger + +__all__ = ["DataCollator"] + + +@dataclass +class DataCollator: + tokenizer: PreTrainedTokenizer + + def __init__(self, tokenizer: PreTrainedTokenizer): + super().__init__() + self.tokenizer = tokenizer + + def __call__(self, instances: Sequence[Dict[str, Any]]) -> Dict[str, Any]: + # Gather everything from the batch + input_ids, labels, media, block_sizes = [], [], {name: [] for name in self.tokenizer.media_tokens}, [] + + media_meta = {} + + media_meta["sound_feature_masks"] = [] + media_meta["sound_embed_masks"] = [] + media_meta["frame_times"] = [] + for instance in instances: + if isinstance(instance["input_ids"], torch.Tensor): + input_ids.append(instance["input_ids"]) + labels.append(instance["labels"]) + for name in media: + objs = instance.get(name) + objs = objs if objs is not None else [] + media[name].append([obj for obj in objs]) + if instance.get("sound") is not None: + for name_k in media_meta: + if "sound" in name_k: + objs = instance.get(name_k) + media_meta[name_k].append([obj for obj in objs]) + if instance.get("video") is not None or instance.get("image") is not None: + for name_k in media_meta: + if "frame" in name_k: + objs = instance.get(name_k) + media_meta[name_k].append([obj for obj in objs]) + if "block_sizes" in instance: + block_sizes.append(instance["block_sizes"]) + else: + block_sizes.append( + [None for _ in range(len(instance.get("image")))] if instance.get("image") is not None else [] + ) + else: + input_ids.extend(instance["input_ids"]) + labels.extend(instance["labels"]) + for name in media: + objs = instance.get(name) + objs = objs if objs is not None else [[] for _ in range(len(instance["input_ids"]))] + media[name].extend(objs) + if instance.get("sound") is not None: + for name_k in media_meta: + if "sound" in name_k: + objs = instance.get(name_k) + media_meta[name_k].extend(objs) + if instance.get("video") is not None or instance.get("image") is not None: + for name_k in media_meta: + if "frame" in name_k: + objs = instance.get(name_k) + media_meta[name_k].append([obj for obj in objs]) + if "block_sizes" in instance: + block_sizes.extend(instance["block_sizes"]) + else: + block_sizes.extend( + [[None for _ in range(len(objs))] for objs in instance.get("image")] + if instance.get("image") is not None + else [[] for _ in range(len(instance["input_ids"]))] + ) + + batch_size = len(input_ids) + + + # Check if the number of media objects (or the number of block sizes) matches the number of media tokens + for name in media: + for k in range(batch_size): + if name == "image" and not all([_ is None for _ in block_sizes[k]]): + actual = len(block_sizes[k]) + else: + actual = len(media[name][k]) + expected = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item() + if actual != expected: + raise ValueError( + f"Number mismatch between {name} objects and {name} tokens. " + f"There are {expected} {name} tokens but {actual} {name} objects." 
+ ) + + # Batchify the inputs + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id, + ) + labels = torch.nn.utils.rnn.pad_sequence( + labels, + batch_first=True, + padding_value=IGNORE_INDEX, + ) + input_ids = input_ids[:, : self.tokenizer.model_max_length] + labels = labels[:, : self.tokenizer.model_max_length] + attention_mask = input_ids.ne(self.tokenizer.pad_token_id) + + # Truncate media objects if necessary + for name in media: + objects = [] + for k in range(batch_size): + if name == "image" and not all([_ is None for _ in block_sizes[k]]): + actual = len(media[name][k]) + num_large_scale_blocks = sum([x * y for x, y in block_sizes[k]]) + num_small_scale_blocks = actual - num_large_scale_blocks + num_small_scale_blocks_each_img = num_small_scale_blocks // len(block_sizes[k]) + expected_full_image = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item() + expected = ( + sum([x * y for x, y in block_sizes[k][:expected_full_image]]) + + num_small_scale_blocks_each_img * expected_full_image + ) + if actual > expected: + logger.warning(f"Truncating the number of {name} objects from {actual} to {expected}") + media[name][k] = media[name][k][:expected] + objects.extend(media[name][k]) + block_sizes[k] = block_sizes[k][:expected_full_image] + else: + actual = len(media[name][k]) + expected = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item() + if actual > expected: + logger.warning(f"Truncating the number of {name} objects from {actual} to {expected}") + media[name][k] = media[name][k][:expected] + objects.extend(media[name][k]) + if name == "image": + block_sizes[k] = block_sizes[k][:expected] + media[name] = objects + + for name in media_meta: + objects = [] + for k in range(batch_size): + try: + objects.extend(media_meta[name][k]) + except: + continue + media_meta[name] = objects + + # Flatten block sizes from [[bls_im1_instance1, bls_im2_instance1], [bls_im1_instance2, bls_im2_instance2], ...] to [bls_im1_instance1, bls_im2_instance1, bls_im1_instance2, bls_im2_instance2, ...] + block_sizes = sum(block_sizes, []) + return { + "input_ids": input_ids, + "media": media, + "media_config": {"image": {"block_sizes": block_sizes}, "video": {}, "speech": {}, "sound": {}}, + "labels": labels, + "attention_mask": attention_mask, + "media_meta": media_meta, + } diff --git a/llava/data/dataset.py b/llava/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..29c8b4c12ceca36f709b4ac87ab18fefae57df71 --- /dev/null +++ b/llava/data/dataset.py @@ -0,0 +1,1635 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import base64 +import copy +import io +import json +import os +import os.path as osp +import random +import time +import warnings +from dataclasses import dataclass +from typing import Dict, Sequence +import math +import numpy as np +import PIL +import torch +import transformers +from PIL import Image, ImageFile +from torch.utils.data import Dataset, default_collate +from transformers import PreTrainedTokenizer +from transformers import AutoFeatureExtractor +import kaldiio +import llava.data.datasets_mixture as datasets_mixture +from llava import conversation as conversation_lib +from llava.constants import DEFAULT_SOUND_TOKEN,DEFAULT_SPEECH_TOKEN, IGNORE_INDEX +from llava.data.collate import DataCollator +from llava.mm_utils import ( + load_audio, + get_num_windows, + tokenizer_image_token, +) +from torchvision import transforms +from llava.train.args import DataArguments, TrainingArguments +from llava.train.sequence_parallel import ( + extract_local_from_list, + extract_local_input_ids, + extract_local_position_ids, + get_pg_manager, +) +from llava.utils.tokenizer import preprocess_conversation +# import torchaudio +from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, UniformClipSampler +import soundfile as sf +from librosa import resample as librosa_resample +import whisper +ImageFile.LOAD_TRUNCATED_IMAGES = True +PIL.Image.MAX_IMAGE_PIXELS = 1000000000 + +def int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + + +def float32_to_int16(x): + x = np.clip(x, a_min=-1., a_max=1.) + return (x * 32767.).astype(np.int16) + + + +def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict: + is_multimodal = data_args.is_multimodal + if not is_multimodal: + return sources + + for source in sources: + concat_values = "".join([sentence["value"] for sentence in source]) + for sid, sentence in enumerate(source): + # In multimodal conversations, we automatically prepend '' at the start of the first sentence if it doesn't already contain one. 
+ + if DEFAULT_SOUND_TOKEN in sentence["value"]: + sentence["value"] = sentence["value"].replace(DEFAULT_SOUND_TOKEN, f"{DEFAULT_SOUND_TOKEN}\n") + sentence["value"] = sentence["value"].replace(f"{DEFAULT_SOUND_TOKEN}\n\n", f"{DEFAULT_SOUND_TOKEN}\n") + if DEFAULT_SPEECH_TOKEN in sentence["value"]: + sentence["value"] = sentence["value"].replace(DEFAULT_SPEECH_TOKEN, f"{DEFAULT_SPEECH_TOKEN}\n") + sentence["value"] = sentence["value"].replace(f"{DEFAULT_SPEECH_TOKEN}\n\n", f"{DEFAULT_SPEECH_TOKEN}\n") + return sources + + +def preprocess_plain( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + # add end signal and concatenate together + conversations = [] + for source in sources: + assert len(source) == 2 + assert DEFAULT_IMAGE_TOKEN in source[0]["value"] + source[0]["value"] = DEFAULT_IMAGE_TOKEN + conversation = source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep + conversations.append(conversation) + # tokenize conversations + input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations] + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer)) + target[:tokenized_len] = IGNORE_INDEX + + return dict(input_ids=input_ids, labels=targets) + + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False, + no_system_prompt: bool = False, +) -> Dict: + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN: + return preprocess_plain(sources, tokenizer) + return default_collate( + [ + preprocess_conversation(conversation, tokenizer, no_system_prompt=no_system_prompt) + for conversation in sources + ] + ) + + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning. + This class is originally implemented by the LLaVA team and modified by + Ji Lin and Haotian Tang. + """ + + def __init__( + self, + data_path: str, + image_folder: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments, + training_args: TrainingArguments, + ): + super().__init__() + try: + with open(data_path) as fp: + list_data_dict = json.load(fp) + except: + with open(data_path) as fp: + list_data_dict = [json.loads(q) for q in fp] + + # rank0_print("Formatting inputs...Skip in lazy mode") + print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.list_data_dict = list_data_dict + self.data_args = data_args + self.image_folder = image_folder + self.wav_processor = AutoFeatureExtractor.from_pretrained('/lustre/fsw/portfolios/adlr/users/sreyang/flamingo_v2/NV-Whisper') + + def __len__(self): + return len(self.list_data_dict) + + @property + def lengths(self): + length_list = [] + for sample in self.list_data_dict: + img_tokens = 128 if "image" in sample else 0 + length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens) + return length_list + + @property + def modality_lengths(self): + length_list = [] + for sample in self.list_data_dict: + if 'duration' in sample.keys(): + duration = sample["duration"] + else: + duration = 10. 
+ try: + cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"]) + int(math.ceil(duration * 25)) + cur_len = cur_len if "sound" in sample else -cur_len + length_list.append(cur_len) + except: + try: + cur_len = 0 + int(math.ceil(duration * 25)) + cur_len = cur_len if "sound" in sample else -cur_len + length_list.append(cur_len) + except: + cur_len = 0 + int(math.ceil(10. * 25)) + cur_len = cur_len if "sound" in sample else -cur_len + length_list.append(cur_len) + return length_list + + @staticmethod + def _load_sound(sound_file, wav_processor, sample_rate=16000, window_length=30.0, window_overlap=0.0, max_num_window=3, audio_start = 0.0): + if sound_file is None: + return None + window_length = int(window_length * sample_rate) + window_overlap = int(window_overlap * sample_rate) + max_num_window = int(max_num_window) + duration = max_num_window * (window_length - window_overlap) + window_overlap + + sound_outputs = [] + audio_feature_masks = [] + audio_embed_masks = [] + + try: + sound_filename = str.split(sound_file, '/')[-1] + if '.ark' in sound_filename: + sound = kaldiio.load_mat(sound_file) + audio_data = sound[1] + audio_data=audio_data.astype(np.float16) + else: + audio_data = load_audio(sound_file, sample_rate, duration, audio_start) # already cuts to max duration + T = len(audio_data) + audio_data = audio_data.reshape(1, -1) + num_windows, full_length = get_num_windows(T, sample_rate, max_num_window) + + audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float() + + for i in range(num_windows): + audio_embed_mask = torch.zeros(750) + start = i * (window_length - window_overlap) + audio_data_tensor_this = audio_data_tensor[:, start:start+window_length] + orig_length = audio_data_tensor_this.shape[1] + audio_data_tensor_this = wav_processor(audio_data_tensor_this.cpu().numpy(), sampling_rate=sample_rate, return_tensors="pt") #.squeeze(0) text="dummy", audios=audio_data_tensor_this, return_tensors="pt") # + sound_outputs.append(audio_data_tensor_this["input_features"]) + # calculate the mask for the input melspec to Whisper + melspec_frames_this_window = int(math.ceil(orig_length / 160)) + feature_attention_mask = torch.zeros(3000, dtype=torch.int32) + feature_attention_mask[:melspec_frames_this_window] = 1 + audio_feature_masks.append(feature_attention_mask.unsqueeze(0)) + # calculate the mask for the output embedding for use in AF2 + conv_lengths = (melspec_frames_this_window - 1) // 2 + 1 + output_embedding_lengths = (conv_lengths - 2) // 2 + 1 + audio_embed_mask[:output_embedding_lengths] = 1 + audio_embed_masks.append(audio_embed_mask) + except: + print('error loading file', sound_file) + sound_outputs.append(torch.zeros(1,128,3000)) + audio_feature_masks.append(torch.zeros(1,3000, dtype=torch.int32)) + audio_embed_masks.append(torch.zeros(750)) + + return torch.stack(sound_outputs, dim=0), torch.stack(audio_feature_masks, dim=0), torch.stack(audio_embed_masks, dim=0) + + @staticmethod + def _load_speech(speech_path,sample_rate=16000): + if speech_path is None: + return None + + speech_outputs = [] + try: + speech = whisper.load_audio(speech_path) + speech = whisper.pad_or_trim(speech) + mel = whisper.log_mel_spectrogram(speech) + speech_outputs.append(mel.unsqueeze(0)) + except: + speech_outputs.append(torch.zeros(1,80,3000)) + return torch.stack(speech_outputs, dim=0) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + sources = self.list_data_dict[i] + if isinstance(i, int): + sources = [sources] + assert 
len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME + + import re + if "sound" in self.list_data_dict[i]: + # chat data loading + if isinstance(self.list_data_dict[i]["sound"],list): + sound_files = self.list_data_dict[i]["sound"] + conversations_raw = self.list_data_dict[i]["conversations"] + + # Step 1: Extract tags in order of appearance + sound_tag_pattern = re.compile(r"") + ordered_sound_tags = [] + + for turn in conversations_raw: + tags = sound_tag_pattern.findall(turn["value"]) + ordered_sound_tags.extend([f"" for tag in tags]) + + # Step 2: Load sound tensors in the order of tags + sound_tensor = [] + audio_feature_masks = [] + audio_embed_masks = [] + sound_token_map = {} + + for tag in ordered_sound_tags: + idx = int(tag.split('-')[1][:-1]) + if tag not in sound_token_map: + this_sound_tensor, af_mask, ae_mask = self._load_sound(sound_file, self.wav_processor, max_num_window=self.data_args.audio_frames) + this_sound_tensor = this_sound_tensor.squeeze(1) # (windows x 750 x 2048) + sound_token_map[tag] = ("\n" * this_sound_tensor.shape[0]).rstrip() + sound_tensor.append(this_sound_tensor) + audio_feature_masks.append(af_mask) + audio_embed_masks.append(ae_mask) + else: + # If already loaded, still append to match sequence + this_sound_tensor, af_mask, ae_mask = self._load_sound(sound_file, self.wav_processor, max_num_window=self.data_args.audio_frames) + this_sound_tensor = this_sound_tensor.squeeze(1) + sound_tensor.append(this_sound_tensor) + audio_feature_masks.append(af_mask) + audio_embed_masks.append(ae_mask) + + + # Process conversations and inject sound markers + conversation = [] + for turn in conversations_raw: + role = turn["from"] + value = turn["value"] + + # Replace any tag with corresponding repeated \n + for tag, sound_token in sound_token_map.items(): + value = value.replace(tag, sound_token) + + conversation.append({ + "from": role, + "value": value.rstrip() + }) + + sources = [conversation] + sound_tensor = torch.cat(sound_tensor, dim=0) + audio_feature_masks = torch.cat(audio_feature_masks, dim=0) + audio_embed_masks = torch.cat(audio_embed_masks, dim=0) + else: + sound_file = self.list_data_dict[i]["sound"] + question = str(self.list_data_dict[i]["conversations"][0]["value"].rstrip()) + answer = str(self.list_data_dict[i]["conversations"][1]["value"]).rstrip() + question = question.replace("\n", "").replace("\n", "").replace("", "") + question = question.replace("\n", "").replace("\n", "").replace("", "") + question = question.replace("\n", "").replace("\n", "").replace("", "") + question = question.replace("\n", "").replace("\n", "").replace("", "") + sound_tensor, audio_feature_masks, audio_embed_masks = self._load_sound(sound_file, self.wav_processor, max_num_window=self.data_args.audio_frames) + sound_tensor=sound_tensor.squeeze(1) # squeeze the irrelevant dimension which was caused due to processor getting 1 batch for processing --> (windows x 750 x 2048) + question = "\n" * sound_tensor.shape[0] + question + conversation = [ + {"from": "human", "value": question}, + {"from": "gpt", "value": answer}, + ] + + sources = [conversation] + data_dict = preprocess( + sources, + self.tokenizer, + has_image=( + "sound" in self.list_data_dict[i] + ), + ) + if isinstance(i, int): + data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) + + if "sound" in self.list_data_dict[i]: + data_dict["sound"] = sound_tensor + data_dict["sound_feature_masks"] = audio_feature_masks + data_dict["sound_embed_masks"] = audio_embed_masks 
+ if "speech" in self.list_data_dict[i]: + data_dict["speech"] = speech_tensor + + return data_dict + + +class LazyMMC4Dataset(Dataset): + """Dataset for supervised fine-tuning. + This class is implemented by Ji Lin and Haotian Tang.""" + + def __init__( + self, + data_path: str, + image_folder: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments, + training_args: TrainingArguments, + image_following_text_only=False, + text_only=False, + ): + super().__init__() + + import pickle + + n_samples = [] + # actually shards and stats info + n_shards = len(os.listdir(data_path)) // 2 + # n_shards = 100 + count_info_list = sorted([f for f in os.listdir(data_path) if f.endswith(".count")])[:n_shards] + n_samples = [int(open(os.path.join(data_path, f)).read().strip()) for f in count_info_list] + + print("total MMC4 samples", sum(n_samples)) # 10,881,869 + + PROCESS_GROUP_MANAGER = get_pg_manager() + if PROCESS_GROUP_MANAGER is not None: + import torch.distributed as dist + + sequence_parallel_size = training_args.seq_parallel_size + else: + sequence_parallel_size = 1 + print("sequence_parallel_size", sequence_parallel_size) + rank = training_args.process_index // sequence_parallel_size # int(os.environ["RANK"]) + world_size = training_args.world_size // sequence_parallel_size # int(os.environ["WORLD_SIZE"]) + shared_size = n_shards // world_size + + gpu_samples = [sum(n_samples[i * shared_size : (i + 1) * shared_size]) for i in range(world_size)] + self.n_samples = min(gpu_samples) * world_size # total size + self.idx_offset = rank * min(gpu_samples) + shard_start, shard_end = rank * shared_size, (rank + 1) * shared_size + print(f" * loading data from shard {shard_start}-{shard_end}") + + shard_names = [d.replace(".count", ".pkl") for d in count_info_list] + shard_names = shard_names[shard_start:shard_end] + + full_data_list = [] + # now load data + for shard_name in shard_names: + # load shard + with open(os.path.join(data_path, shard_name), "rb") as f: + data_list = pickle.load(f) + + full_data_list.extend(data_list) + + print(f"* loaded totally {len(full_data_list)} samples") + + self.data_list = full_data_list + + self.tokenizer = tokenizer + self.data_args = data_args + self.image_folder = image_folder + + self.image_following_text_only = image_following_text_only + self.text_only = text_only + + def __len__(self): + # return len(self.data_list) + return self.n_samples + + @property + def modality_lengths(self): + # Estimate the number of tokens after tokenization, used for length-grouped sampling + length_list = [] + for info in self.data_list: + num_images = min(6, len(info["image_info"])) + sentences = [info["text_list"][x["matched_text_index"]] for x in info["image_info"][:num_images]] + # The unit of cur_len is "words". We assume 1 word = 2 tokens. + cur_len = num_images * self.num_image_tokens // 2 + sum([len(x) for x in sentences]) + length_list.append(cur_len) + return length_list + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + info = self.data_list[i - self.idx_offset] + + sentences = info["text_list"] + # kentang-mit@: remove existing tokens in the sentences + for ix in range(len(sentences)): + # if this is an html tag, we still preserve its semantic meaning + sentences[ix] = sentences[ix].replace("", "") + sim_matrix = info["similarity_matrix"] # we do not use this... 
+ + # convert images from base64 to PIL and filter based on image-text similarity + images, sentence_ixs = [], [] + if not self.text_only: + for sample_image, sim_vec in zip(info["image_info"], sim_matrix): + image_base64 = sample_image["image_base64"] + rawbytes = base64.b64decode(image_base64) + + sim_ix = sample_image["matched_text_index"] + # sim_ix = np.argmax(sim_vec) + # sim_score = sim_vec[sim_ix] + + # filter to images >= 5KB + # if len(rawbytes) // 1000 <= 5: + # continue + # if sim_score < 0.24: + # continue + image = Image.open(io.BytesIO(rawbytes)).convert("RGB") + + images.append(image) + sentence_ixs.append(sim_ix) + + # constrain max num 6 images + max_num_images = 6 + if len(images) > max_num_images: + images = images[:max_num_images] + sentence_ixs = sentence_ixs[:max_num_images] + + # reorder images according to text insertion + images = [images[iii] for iii in np.argsort(sentence_ixs)] + + # preprocess and tokenize text + for ix in sentence_ixs: + sentences[ix] = f"\n{sentences[ix]}" + + if self.image_following_text_only: + # use pad tokens to divide sentence pieces + text = self.tokenizer.pad_token.join(sentences) + else: + text = " ".join(sentences) + # whitespace cleanup + text = text.replace(" ", "").replace(" ", "") + text = f"{text}{self.tokenizer.eos_token}" # add eos token + + if len(images) > 0: + if self.data_args.image_aspect_ratio == "dynamic_s2": + images, block_sizes = dynamic_s2_process_images_and_prompt( + images, text, self.data_args, self.image_folder + ) + elif self.data_args.image_aspect_ratio == "dynamic": + images, text = dynamic_process_images_and_prompt( + images, text, self.data_args, self.image_folder, max_tiles=6 + ) + else: + images = torch.stack([process_image(image, self.data_args, self.image_folder) for image in images]) + + # the same size for all images, so we concat + # cur_token_len = ( + # images[0].shape[-2] // self.multimodal_cfg["patch_size"] + # ) * (images[0].shape[-1] // self.multimodal_cfg["patch_size"]) + # cur_token_len += self.multimodal_cfg["n_extra_patch"] + else: + images = None + # cur_token_len = 0 + + input_ids = tokenizer_image_token( + text, + self.tokenizer, + return_tensors="pt", + ) + + image_token_id = self.tokenizer.media_token_ids["image"] + + # now check the case where the last token is image patch token + if input_ids[-1] == image_token_id: # need to remove one last image + last_non_im_patch_indices = torch.where(input_ids != image_token_id)[0][-1] + 1 + input_ids = input_ids[:last_non_im_patch_indices] + + n_im_patch = (input_ids == image_token_id).sum().item() + + if self.data_args.image_aspect_ratio != "dynamic_s2": + images = images[:n_im_patch] + assert len(images) == n_im_patch, print(text, input_ids) + assert len(input_ids.shape) == 1, "Unexpected shape of 'input_ids' from MMC4." 
+ input_ids = ( + torch.concat([torch.tensor([self.tokenizer.bos_token_id]), input_ids]) + if self.tokenizer.bos_token_id is not None and input_ids[0] != self.tokenizer.bos_token_id + else input_ids + ) + targets = input_ids.clone() + + if self.image_following_text_only: # keep only text after leading image token + # remove loss for any token before the first token + label_idx = 0 + while label_idx < targets.shape[-1] and targets[label_idx] != image_token_id: + targets[label_idx] = IGNORE_INDEX + label_idx += 1 + + pad_token = self.tokenizer.convert_tokens_to_ids([self.tokenizer.pad_token])[0] + + pad_token_idxs = torch.where(targets == pad_token)[0] + for pad_token_idx in pad_token_idxs: + token_idx = pad_token_idx + 1 + while token_idx < targets.shape[-1] and targets[token_idx] != image_token_id: + targets[token_idx] = IGNORE_INDEX + token_idx += 1 + # do not train on padding tokens + targets[targets == pad_token] = IGNORE_INDEX + + # mask image tokens is unnecessary for llava-1.5 + # targets[targets == IMAGE_TOKEN_INDEX] = IGNORE_INDEX + # print(input_ids.shape) + + data_dict = dict(input_ids=input_ids, labels=targets, image=images) + if self.data_args.image_aspect_ratio == "dynamic_s2": + data_dict["block_sizes"] = block_sizes + + return data_dict + + +class LazyCoyoDataset(Dataset): + """Dataset for supervised fine-tuning. + This class is implemented by Ji Lin and Haotian Tang.""" + + num_image_tokens = 576 + + def __init__( + self, + data_path: str, + image_folder: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments, + training_args: TrainingArguments, + # kentang-mit@: balance the total number of tokens for Coyo and MMC4. + n_samples_per_idx=4, + ): + super().__init__() + + import pickle + + n_samples = [] + # actually shards and stats info + n_shards = len(os.listdir(data_path)) // 2 + # n_shards = 100 + count_info_list = sorted([f for f in os.listdir(data_path) if f.endswith(".count")])[:n_shards] + n_samples = [int(open(os.path.join(data_path, f)).read().strip()) for f in count_info_list] + + print("total COYO samples", sum(n_samples)) + + PROCESS_GROUP_MANAGER = get_pg_manager() + if PROCESS_GROUP_MANAGER is not None: + import torch.distributed as dist + + sequence_parallel_size = training_args.seq_parallel_size + else: + sequence_parallel_size = 1 + print("sequence_parallel_size", sequence_parallel_size) + rank = training_args.process_index // sequence_parallel_size # int(os.environ["RANK"]) + world_size = training_args.world_size // sequence_parallel_size # int(os.environ["WORLD_SIZE"]) + shared_size = n_shards // world_size + + gpu_samples = [ + sum(n_samples[i * shared_size : (i + 1) * shared_size]) // n_samples_per_idx for i in range(world_size) + ] + self.n_samples = min(gpu_samples) * world_size # total size + self.idx_offset = rank * min(gpu_samples) + + shard_start, shard_end = rank * shared_size, (rank + 1) * shared_size + print(f" * loading data from shard {shard_start}-{shard_end}") + + shard_names = [d.replace(".count", ".pkl") for d in count_info_list] + shard_names = shard_names[shard_start:shard_end] + + full_data_list = [] + # now load data + for shard_name in shard_names: + # load shard + with open(os.path.join(data_path, shard_name), "rb") as f: + shard_data = pickle.load(f) + random.seed(42) + if "mmc4" in data_path: + random.shuffle(shard_data) # shuffle for MMC4cap only + full_data_list.extend(shard_data) + + print(f"* loaded totally {len(full_data_list)} samples") + + # now pack the samples into groups + n_groups = 
len(full_data_list) // n_samples_per_idx + full_data_list = [ + full_data_list[i : i + n_samples_per_idx] for i in range(0, len(full_data_list), n_samples_per_idx) + ] + if len(full_data_list[-1]) < n_samples_per_idx: + full_data_list = full_data_list[:-1] + assert len(full_data_list) == n_groups + print(f"split into {n_groups} groups") + + self.data_list = full_data_list + + self.tokenizer = tokenizer + self.data_args = data_args + self.image_folder = image_folder + + def __len__(self): + # return len(self.data_list) + return self.n_samples + + @property + def modality_lengths(self): + # Estimate the number of tokens after tokenization, used for length-grouped sampling + length_list = [] + for samples in self.data_list: + cur_len = sum([len(conv["text" if "text" in conv else "caption"].split()) for conv in samples]) + # The unit of cur_len is "words". We assume 1 word = 2 tokens. + cur_len = cur_len + len(samples) * self.num_image_tokens // 2 + length_list.append(cur_len) + return length_list + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + CONCAT_SAMPLES = False + info_list = self.data_list[i - self.idx_offset] + + text_list = [] + image_list = [] + + for sample in info_list: + caption_key = ( + "text" if "text" in sample else "caption" + ) # kentang-mit@: remove existing tokens in the sentences + # kentang-mit@: remove existing token. + # if this is an html tag, we still preserve its semantic meaning + sample[caption_key] = sample[caption_key].replace("", "") + text_list.append(DEFAULT_IMAGE_TOKEN + "\n" + sample[caption_key] + self.tokenizer.eos_token) + if "image" in sample: + image_base64 = sample["image"] + rawbytes = base64.b64decode(image_base64) + else: + rawbytes = sample["rawbytes"] + image = Image.open(io.BytesIO(rawbytes)).convert("RGB") + image_list.append(image) + + image_list = torch.stack([process_image(image, self.data_args, self.image_folder) for image in image_list]) + + if CONCAT_SAMPLES: + # into capcap... + text_list = "".join(text_list) + + input_ids = self.tokenizer( + text_list, + return_tensors="pt", + padding="longest", + max_length=self.tokenizer.model_max_length, + truncation=True, + ).input_ids # 4, seq_len + + input_ids = input_ids[0] + + else: + input_ids = [ + tokenizer_image_token( + prompt, + self.tokenizer, + return_tensors="pt", + ) + for prompt in text_list + ] + # print([x.shape[0] for x in input_ids], [len(x.split()) for x in text_list], [len(re.findall(r"]*>", x)) for x in text_list]) + + # input_ids = torch.nn.utils.rnn.pad_sequence( + # input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id + # ) + + targets = copy.deepcopy(input_ids) + for i in range(len(targets)): + targets[i][targets[i] == self.tokenizer.pad_token_id] = IGNORE_INDEX + + return dict(input_ids=input_ids, labels=targets, image=image_list) + + +class LazyWDSDataset(Dataset): + """Dataset for supervised fine-tuning. 
+ This class is implemented by Ji Lin and Ligeng Zhu.""" + + def __init__( + self, + data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments, + image_folder: str, + training_args: TrainingArguments, + ): + super().__init__() + n_samples = [] + n_shards = len(os.listdir(data_path)) // 3 + for shard in range(n_shards): + with open(os.path.join(data_path, f"{shard:05d}_stats.json")) as f: + info = json.load(f) + n_samples.append(info["successes"]) + + # print(f"[DEBUG] {data_path} total samples", sum(n_samples)) # 10,881,869 + + PROCESS_GROUP_MANAGER = get_pg_manager() + if PROCESS_GROUP_MANAGER is not None: + import torch.distributed as dist + + sequence_parallel_size = training_args.seq_parallel_size + else: + sequence_parallel_size = 1 + print("sequence_parallel_size", sequence_parallel_size) + rank = training_args.process_index // sequence_parallel_size # int(os.environ["RANK"]) + world_size = training_args.world_size // sequence_parallel_size # int(os.environ["WORLD_SIZE"]) + shared_size = n_shards // world_size + print("rank", rank, "world_size", world_size, "shared_size", shared_size) + gpu_samples = [sum(n_samples[i * shared_size : (i + 1) * shared_size]) for i in range(world_size)] + self.n_samples = min(gpu_samples) * world_size # total size + self.idx_offset = rank * min(gpu_samples) + shard_start, shard_end = rank * shared_size, (rank + 1) * shared_size + print(f" * loading data from shard {shard_start}-{shard_end}") + + tar_list = [f"{shard_idx:05d}.tar" for shard_idx in range(shard_start, shard_end)] + + self.data_list = [] + t1 = time.time() + for tar in tar_list: + tmp_path = f"/tmp/ccs{tar}" + tar_path = os.path.join(data_path, tar) + + if PROCESS_GROUP_MANAGER is not None: + dist.barrier() + if PROCESS_GROUP_MANAGER.sp_rank == 0: + os.makedirs(tmp_path, exist_ok=True) + os.system(f"tar -xkf {tar_path} -C {tmp_path}") + dist.barrier() + else: + os.makedirs(tmp_path, exist_ok=True) + os.system(f"tar -xkf {tar_path} -C {tmp_path}") + + txt_list = [f for f in os.listdir(tmp_path) if f.endswith(".txt")] + + for txt in txt_list: + caption = open(os.path.join(tmp_path, txt)).read().strip() + image_path = os.path.join(tmp_path, txt.split(".")[0] + ".jpg") + self.data_list.append({"caption": caption, "image": image_path}) + t2 = time.time() + print(f"Loading done. 
Total time: {t2 - t1:.2f} seconds") + + self.tokenizer = tokenizer + self.data_args = data_args + self.image_folder = image_folder + + def __len__(self): + return self.n_samples + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + + # print("i", i, "idx_offset", self.idx_offset, "len", len(self.data_list)) + info = self.data_list[i - self.idx_offset] + caption, image_path = info["caption"], info["image"] + + rand_prompt = "\n" + sources = [ + { + "image": image_path, + "conversations": [ + {"from": "human", "value": rand_prompt}, + {"from": "gpt", "value": caption}, + ], + } + ] + + # one example of sources + # [{'id': 'GCC_train_001738742', 'image': 'GCC_train_001738742.jpg', 'conversations': [{'from': 'human', 'value': 'Provide a brief description of the given image.\n'}, {'from': 'gpt', 'value': 'a sketch of an ostrich'}]}] + if "image" in sources[0]: + image = process_image(sources[0]["image"], self.data_args, self.image_folder) + image = torch.unsqueeze(image, dim=0) + # now random pick some context samples for training + if hasattr(self.data_args, "num_shots"): + if self.data_args.num_shots > 0: + raise NotImplementedError + else: + raise NotImplementedError + + data_dict = preprocess([sources[0]["conversations"]], self.tokenizer, has_image=True) + + if isinstance(i, int): + data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) + + # image exist in the data + if image is not None: + data_dict["image"] = image + else: + raise NotImplementedError + + return data_dict + + +class LazyCCSWebDataset(Dataset): + """Dataset for supervised fine-tuning. + This class is implemented by Ligeng Zhu.""" + + def __init__( + self, + data_path: str, + image_folder: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments, + training_args: TrainingArguments, + ): + super().__init__() + t1 = time.time() + + from llava.data.simple_vila_webdataset import VILAWebDataset + + print("[DEBUG] ", osp.abspath(data_path)) + self.dataset = VILAWebDataset(data_path=osp.abspath(data_path)) + + t2 = time.time() + print(f"Loading done. 
Total time: {t2 - t1:.2f} seconds") + + self.tokenizer = tokenizer + self.data_args = data_args + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + # info = self.data_list[i - self.idx_offset] + # caption, image_path = info["caption"], info["image"] + info = self.dataset[i] + if ".jpg" in info: + caption, image_path = info[".txt"], info[".jpg"] + elif ".png" in info: + caption, image_path = info[".txt"], info[".png"] + elif ".webp" in info: + caption, image_path = info[".txt"], info[".webp"] + elif ".bmp" in info: + caption, image_path = info[".txt"], info[".bmp"] + elif ".tiff" in info: + caption, image_path = info[".txt"], info[".tiff"] + else: + print(info.keys()) + print(info) + raise KeyError + + caption = caption.replace("", "") + if isinstance(image_path, io.BytesIO): + image_path = Image.open(image_path).convert("RGB") + + if not isinstance(image_path, PIL.Image.Image): + print(image_path) + print(info.keys()) + print(type(image_path)) + raise NotImplementedError + + rand_prompt = "\n" + sources = [ + { + "image": image_path, + "conversations": [ + {"from": "human", "value": rand_prompt}, + {"from": "gpt", "value": caption}, + ], + } + ] + + # one example of sources + # [{'id': 'GCC_train_001738742', 'image': 'GCC_train_001738742.jpg', 'conversations': [{'from': 'human', 'value': 'Provide a brief description of the given image.\n'}, {'from': 'gpt', 'value': 'a sketch of an ostrich'}]}] + if "image" in sources[0]: + image = process_image(sources[0]["image"], self.data_args, image_folder=None) + image = torch.unsqueeze(image, dim=0) + # now random pick some context samples for training + if hasattr(self.data_args, "num_shots"): + if self.data_args.num_shots > 0: + raise NotImplementedError + else: + raise NotImplementedError + + data_dict = preprocess([sources[0]["conversations"]], self.tokenizer, has_image=True) + + if isinstance(i, int): + data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) + + # image exist in the data + if image is not None: + data_dict["image"] = image + else: + raise NotImplementedError + + return data_dict + + +from functools import lru_cache + + +@lru_cache(maxsize=16) +def lru_json_load(fpath): + with open(fpath) as fp: + return json.load(fp) + + +class LazyCoyoWebDataset(Dataset): + """Dataset for supervised fine-tuning. + This class is implemented by Ligeng Zhu.""" + + num_image_tokens = 576 + + def __init__( + self, + data_path: str, + image_folder: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments, + training_args: TrainingArguments, + # kentang-mit@: balance the total number of tokens for Coyo and MMC4. + n_samples_per_idx=4, + ): + super().__init__() + + from llava.data.simple_vila_webdataset import VILAWebDataset + + print("[DEBUG] ", osp.abspath(data_path)) + self.dataset = VILAWebDataset(data_path=osp.abspath(data_path), meta_path=data_args.meta_path) + + if data_args.start_idx >= 0 and data_args.end_idx >= 0: + # Ligeng: support slicing for ablate different subsets. 
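+            # NOTE: start_idx / end_idx are fractions of the full dataset rather than
+            # absolute indices, e.g. start_idx=0.0, end_idx=0.1 keeps roughly the first
+            # 10% of samples (illustrative values, not defaults).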
+ total = len(self.dataset) + start_idx = int(total * data_args.start_idx) + end_idx = int(total * data_args.end_idx) + print(f"loading subset from {start_idx} to {end_idx}, total {total}") + self.dataset = torch.utils.data.Subset(self.dataset, range(start_idx, end_idx)) + + # For caption choice, + # if None: use original caption + # if a folder path: use specified caption to override original one (choice1) + # if a folder path: use specified caption and concat with original one (choice2) + self.caption_choice = None + self.caption_choice_2 = None + self.data_path = data_path + + if data_args.caption_choice is not None: + self.caption_choice = data_args.caption_choice + print("[recap] Override coyo caption using ", self.caption_choice) + + if data_args.caption_choice_2 is not None: + self.caption_choice_2 = data_args.caption_choice_2 + print("[recapv2] Override coyo caption using ", self.caption_choice_2) + + print("total samples", len(self.dataset)) + PROCESS_GROUP_MANAGER = get_pg_manager() + if PROCESS_GROUP_MANAGER is not None: + import torch.distributed as dist + + sequence_parallel_size = training_args.seq_parallel_size + sequence_parallel_rank = PROCESS_GROUP_MANAGER.sp_rank + else: + sequence_parallel_size = 1 + print("sequence_parallel_size", sequence_parallel_size) + rank = ( + training_args.process_index // sequence_parallel_size if "RANK" in os.environ else 2 + ) # int(os.environ["RANK"]) + world_size = ( + training_args.world_size // sequence_parallel_size if "WORLD_SIZE" in os.environ else 32 + ) # int(os.environ["WORLD_SIZE"]) + print( + "rank", + rank, + "world_size", + world_size, + ) + + self.n_samples_per_idx = n_samples_per_idx + # self.n_samples = len(self.dataset) // n_samples_per_idx + self.tokenizer = tokenizer + self.data_args = data_args + + def __len__(self): + return len(self.dataset) // self.n_samples_per_idx + + @property + def modality_lengths(self): + # Estimate the number of tokens after tokenization, used for length-grouped sampling + length_list = [] + for samples in self.data_list: + cur_len = sum([len(conv["text" if "text" in conv else "caption"].split()) for conv in samples]) + # The unit of cur_len is "words". We assume 1 word = 2 tokens. 
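+            # i.e. the estimate stays in "word" units: each image contributes
+            # num_image_tokens (576 here) tokens, counted as 576 // 2 = 288 words.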
+ cur_len = cur_len + len(samples) * self.num_image_tokens // 2 + length_list.append(cur_len) + return length_list + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + CONCAT_SAMPLES = False + # info_list = self.dataset[i - self.idx_offset] + + begin_idx, end_idx = ( + i * self.n_samples_per_idx, + (i + 1) * self.n_samples_per_idx, + ) + end_idx = min(end_idx, len(self.dataset)) + + text_list = [] + image_list = [] + + for idx in range(begin_idx, end_idx): + info = self.dataset[idx] + if ".jpg" in info: + caption, image_path = info[".txt"], info[".jpg"] + elif ".png" in info: + caption, image_path = info[".txt"], info[".png"] + elif ".webp" in info: + caption, image_path = info[".txt"], info[".webp"] + elif ".bmp" in info: + caption, image_path = info[".txt"], info[".bmp"] + elif ".tiff" in info: + caption, image_path = info[".txt"], info[".tiff"] + else: + print(info.keys()) + print(info) + raise KeyError + + if self.caption_choice is not None: + # load new captions + shard = info["__shard__"] + url = info[".json"]["url"] + tar_name = osp.relpath(osp.realpath(shard), osp.realpath(self.data_path)) + # tar_name = osp.dirname(shard) + shard_json_path = osp.join(self.caption_choice, tar_name + ".json") + try: + shard_json = lru_json_load(shard_json_path) + try: + caption = shard_json[url]["output"] + except KeyError: + print(f"{url} not in caption. fallback to original caption temporarially") + except: + print(f"shard_json_path {shard_json_path} not found. fallback to original caption temporarially") + caption = caption.replace("", "") + text_list.append(DEFAULT_IMAGE_TOKEN + caption + self.tokenizer.eos_token) + + if isinstance(image_path, io.BytesIO): + image_path = Image.open(image_path).convert("RGB") + + if not isinstance(image_path, PIL.Image.Image): + print(image_path) + print(info.keys()) + print(type(image_path)) + raise NotImplementedError + + image_list.append(image_path) + + # image_list = torch.stack([process_image(image, self.data_args, image_folder=None) for image in image_list]) + # NOTE(fix by ligeng) + # now image_list should return a list of image tensor where each has a dimension of (1, c, h, w) + image_list = [process_image(image, self.data_args, image_folder=None).unsqueeze(0) for image in image_list] + + if CONCAT_SAMPLES: + # into capcap... 
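+            # NOTE: CONCAT_SAMPLES is hard-coded to False above, so this branch is not
+            # taken; the else-branch below tokenizes each
+            # DEFAULT_IMAGE_TOKEN + caption + eos string separately via tokenizer_image_token.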
+ text_list = "".join(text_list) + + input_ids = self.tokenizer( + text_list, + return_tensors="pt", + padding="longest", + max_length=self.tokenizer.model_max_length, + truncation=True, + ).input_ids # 4, seq_len + + input_ids = input_ids[0] + else: + input_ids = [ + tokenizer_image_token( + prompt, + self.tokenizer, + return_tensors="pt", + ) + for prompt in text_list + ] + input_ids = [ + ( + torch.concat([torch.tensor([self.tokenizer.bos_token_id]), input_ids_i]) + if input_ids_i[0] != self.tokenizer.bos_token_id + else input_ids_i + ) + for input_ids_i in input_ids + ] + + targets = copy.deepcopy(input_ids) + for i in range(len(targets)): + targets[i][targets[i] == self.tokenizer.pad_token_id] = IGNORE_INDEX + + return dict(input_ids=input_ids, labels=targets, image=image_list) + + +class LazyVideoWebDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__( + self, + data_path: str, + image_folder: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments, + training_args: TrainingArguments, + # cache_path: str, + # n_samples_per_idx=4, + ): + super().__init__() + + # from llava.data.simple_video_dataset import SimpleVideoDataset + + from llava.data.simple_vila_webdataset import VILAWebDataset + + print("[DEBUG] ", osp.abspath(data_path)) + self.dataset = VILAWebDataset( + data_path=osp.abspath(data_path), + meta_path=f"{osp.abspath(data_path)}/wids-meta.json", + # cache_dir=cache_path, + ) + + # None: use original caption + # Folder path: use original caption + self.caption_choice = None + self.data_path = data_path + + if data_args.caption_choice is not None: + self.caption_choice = data_args.caption_choice + print("[recap] Override LazyVideo caption using ", self.caption_choice) + + print("total samples", len(self.dataset)) + # InternVid: TODO + PROCESS_GROUP_MANAGER = get_pg_manager() + if PROCESS_GROUP_MANAGER is not None: + import torch.distributed as dist + + sequence_parallel_size = training_args.seq_parallel_size + sequence_parallel_rank = PROCESS_GROUP_MANAGER.sp_rank + else: + sequence_parallel_size = 1 + print("sequence_parallel_size", sequence_parallel_size) + rank = ( + training_args.process_index // sequence_parallel_size if "RANK" in os.environ else 2 + ) # int(os.environ["RANK"]) + world_size = ( + training_args.world_size // sequence_parallel_size if "WORLD_SIZE" in os.environ else 32 + ) # int(os.environ["WORLD_SIZE"]) + print( + "rank", + rank, + "world_size", + world_size, + ) + self.rank = rank + # rank = int(os.environ["RANK"]) if "RANK" in os.environ else 2 + # world_size = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 32 + + self.tokenizer = tokenizer + self.data_args = data_args + + self.missing_uids = set() + + def __len__(self): + return len(self.dataset) + + @property + def modality_lengths(self): + # Estimate the number of tokens after tokenization, used for length-grouped sampling + length_list = [] + for samples in self.data_list: + cur_len = sum([len(conv["text" if "text" in conv else "caption"].split()) for conv in samples]) + # The unit of cur_len is "words". We assume 1 word = 2 tokens. 
+ cur_len = cur_len + len(samples) * self.num_image_tokens // 2 + length_list.append(cur_len) + return length_list + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + ADD_TEXT_PROMPT = False + num_video_frames = self.data_args.num_video_frames if hasattr(self.data_args, "num_video_frames") else 8 + loader_fps = self.data_args.fps if hasattr(self.data_args, "fps") else 0.0 + + info = self.dataset[i] + + caption = "" + # print(info) + if ".mp4" in info: + caption, video_path = info[".txt"], info[".mp4"] + else: + video_path = None + caption = "Empty video." + + images, frames_loaded, _ = LazySupervisedDataset._load_video( + video_path, num_video_frames, loader_fps, self.data_args + ) + + if frames_loaded == 0: + caption = "Empty video." + + if self.caption_choice is not None: + shard = info["__shard__"] + uuid = osp.join(info["__shard__"], info["__key__"]) + url = info["__key__"] + tar_name = osp.basename(info["__shard__"]) + + try: + shard_json_path = osp.join(self.caption_choice, tar_name.replace(".tar", ".json")) + shard_json = lru_json_load(shard_json_path) + caption = shard_json[url]["summary"]["output"] + except (KeyError, FileNotFoundError, json.decoder.JSONDecodeError): + if uuid not in self.missing_uids: + print("override caption not found for ", uuid) + self.missing_uids.add(uuid) + + # print(f"[DEBUG {uuid}]", caption) + + frames_loaded_successfully = len(images) + if caption is None: + caption = "" + prompt = "\n" * frames_loaded_successfully + caption + image_tensor = torch.stack([process_image(image, self.data_args, None) for image in images]) + + input_ids = tokenizer_image_token( + prompt, + self.tokenizer, + return_tensors="pt", + ) + targets = copy.deepcopy(input_ids) + data_dict = dict(input_ids=input_ids, labels=targets, image=image_tensor) + + return data_dict + + +class DataCollatorForSupervisedDatasetSeqParallel: + """Collate examples for supervised fine-tuning. + This class is originally implemented by the LLaVA team and + modified by Haotian Tang.""" + + def __init__( + self, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments, + training_args: TrainingArguments, + sp_degree: int, + sp_rank: int, + ring_degree: int, + ring_type: str, + ): + self.tokenizer = tokenizer + self.data_args = data_args + self.training_args = training_args + self.sp_degree = sp_degree + self.sp_rank = sp_rank + self.ring_degree = ring_degree + self.ring_type = ring_type + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels, images = [], [], [] + image_token_id = self.tokenizer.media_token_ids["image"] + video_token_id = self.tokenizer.media_token_ids["video"] + + for instance in instances: + if not isinstance(instance["input_ids"], list): + input_ids.append(instance["input_ids"]) + else: + input_ids += instance["input_ids"] + if not isinstance(instance["labels"], list): + labels.append(instance["labels"]) + else: + labels += instance["labels"] + # Note (kentang-mit@: we do not directly push tensors to + # images, but list of tensors. 
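+            # NOTE: for video samples, the single video placeholder below is expanded into
+            # one image token plus an encoded "\n" per frame, and the corresponding label
+            # positions are filled with IGNORE_INDEX so they do not contribute to the loss.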
+ if "video" in instance: + instance["image"] = torch.cat(instance["video"]) + video_id_pos = torch.where(input_ids[-1] == video_token_id)[0][0] + replace_ids = torch.Tensor( + ([image_token_id] + self.tokenizer.encode("\n")) * instance["image"].shape[0], + device=input_ids[-1].device, + ) + input_ids[-1] = torch.cat( + [input_ids[-1][:video_id_pos], replace_ids, input_ids[-1][video_id_pos + 1 :]] + ).to(input_ids[-1].dtype) + labels[-1] = torch.cat( + [ + labels[-1][:video_id_pos], + torch.Tensor([IGNORE_INDEX] * instance["image"].shape[0] * 2), + labels[-1][video_id_pos + 1 :], + ] + ).to(labels[-1].dtype) + instance.pop("video") + + if "image" in instance: + cur_image = instance["image"] + assert len(cur_image.shape) == 4 + # n_images, 3, size, size + if cur_image.shape[0] == 0: + warnings.warn("loaded one sample without images.") + if not isinstance(instance["input_ids"], list): + # datasets other than coyo, not packing >1 samples together + images.append(cur_image) + else: + # coyo-like datasets + images.extend(cur_image.chunk(cur_image.size(0), 0)) + else: + warnings.warn("loaded one sample without images.") + images.append([]) + # kentang-mit@: we need to make sure these two lists have + # the same length. We will use input_ids to filter out images corresponding + # to truncated tokens later. + + max_num_images = max([len(_images) for _images in images]) + for _images, _input_ids in zip(images, input_ids): + assert ( + len(_images) == (_input_ids == image_token_id).sum().item() + ), f"Number mismatch between images and placeholder image tokens in 'len(_images) == (_input_ids == image_token_id).sum().item()'.\ + Expect to have {len(_images)} images but only found {(_input_ids == image_token_id).sum().item()} images in tokens. \ + Error input_ids: {_input_ids} {self.tokenizer.decode([x if x != -200 else 200 for x in _input_ids])}" + + NUM_TOKENS_PER_IMAGE = self.data_args.num_image_tokens + if hasattr(self.data_args.image_processor, "crop_size"): + crop_size = self.data_args.image_processor.crop_size + else: + crop_size = self.data_args.image_processor.size + + # Init the padding sample + seq_id = 0 + while seq_id < len(input_ids): + # Skip the samples without images + dummy_image = torch.ones((1, 3, crop_size["height"], crop_size["width"]), device=input_ids[seq_id].device) + # dummy input_ids include one bos, one image token, and one eos + dummy_input_ids = torch.zeros_like(input_ids[seq_id][:3]) + dummy_input_ids[0] = self.tokenizer.bos_token_id + dummy_input_ids[1] = image_token_id + dummy_input_ids[2] = self.tokenizer.eos_token_id + dummy_labels = copy.deepcopy(dummy_input_ids) + dummy_labels[:2] = IGNORE_INDEX + dummy_seqlen = NUM_TOKENS_PER_IMAGE + 2 # TODO: Check the hard coding of 2 + dummy_position_ids = torch.arange(start=0, end=dummy_seqlen, dtype=torch.int32) + break + + # Sort with the real length of the sequence + combined = sorted( + zip(input_ids, labels, images), + key=lambda x: len(x[2]) * (NUM_TOKENS_PER_IMAGE - 1) + x[0].size(-1), + reverse=True, # Start Packing from the sequence with most images. 
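+            # NOTE: the sort key estimates the packed length after media expansion: each
+            # image placeholder later becomes NUM_TOKENS_PER_IMAGE tokens, i.e. it adds
+            # (NUM_TOKENS_PER_IMAGE - 1) tokens on top of the placeholder itself.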
+ ) + sorted_ids, sorted_labels, sorted_images = zip(*combined) + sorted_ids, sorted_labels, sorted_images = list(sorted_ids), list(sorted_labels), list(sorted_images) + max_seq_length = self.tokenizer.model_max_length # len(sorted_ids[0]) + max_sample_len = 0 + + batches = [] + label_batches = [] + position_ids = [] + batch_images = [] + seqlens_in_batch = [] + + i = 0 + while i < len(sorted_ids): + current_batch = torch.tensor([], dtype=torch.int32) + current_label_batch = torch.tensor([], dtype=torch.int32) + current_position_ids = torch.tensor([], dtype=torch.int32) + current_batch_images = [] + current_num_images = 0 + current_len = 0 + current_num_samples = 0 + + # Pack a few samples into one sample + while i < len(sorted_ids): + num_images = (sorted_ids[i] == image_token_id).sum().item() + num_image_tokens_added = num_images * (NUM_TOKENS_PER_IMAGE - 1) + num_incoming_tokens = sorted_ids[i].size(-1) + num_image_tokens_added + + # Handle RingAttn_Varlen which requires `seqlens_in_batch` should be divisible by `ring_degree` + if self.ring_degree > 1: + RING_PAD_TOKEN_INDEX = 2 + if self.ring_type == "ring_varlen": + if num_incoming_tokens % self.sp_degree != 0: + pad_len = self.sp_degree - num_incoming_tokens % self.sp_degree + num_incoming_tokens += pad_len + # pad `input_ids` + pad_tensor = torch.full( + (pad_len,), RING_PAD_TOKEN_INDEX, dtype=sorted_ids[i].dtype, device=sorted_ids[i].device + ) + sorted_ids[i] = torch.cat([sorted_ids[i], pad_tensor]) + + # pad `label` + pad_label_tensor = torch.full( + (pad_len,), IGNORE_INDEX, dtype=sorted_labels[i].dtype, device=sorted_labels[i].device + ) + sorted_labels[i] = torch.cat([sorted_labels[i], pad_label_tensor]) + elif self.ring_type == "zigzag_ring_varlen": + self.zigzag_sp_degree = self.sp_degree * 2 + if num_incoming_tokens % self.zigzag_sp_degree != 0: + pad_len = self.zigzag_sp_degree - num_incoming_tokens % self.zigzag_sp_degree + num_incoming_tokens += pad_len + # pad `input_ids` + pad_tensor = torch.full( + (pad_len,), RING_PAD_TOKEN_INDEX, dtype=sorted_ids[i].dtype, device=sorted_ids[i].device + ) + sorted_ids[i] = torch.cat([sorted_ids[i], pad_tensor]) + + # pad `label` + pad_label_tensor = torch.full( + (pad_len,), IGNORE_INDEX, dtype=sorted_labels[i].dtype, device=sorted_labels[i].device + ) + sorted_labels[i] = torch.cat([sorted_labels[i], pad_label_tensor]) + else: + raise ValueError(f"Invalid ring_type: {self.ring_type}") + + if num_incoming_tokens > max_seq_length: + print( + f"Warning: Skipping one packed sample with {num_incoming_tokens} tokens,\ + please consider increase max seq len {max_seq_length}." 
+ ) + i += 1 + continue + + if ( + (current_num_images == 0) + or (current_num_images < self.sp_degree) + or ( + (current_num_images + num_images <= max_num_images) + and (current_len + num_incoming_tokens <= max_sample_len) + ) + ) and (current_len + num_incoming_tokens <= max_seq_length): + current_num_images += num_images + current_len += num_incoming_tokens + current_num_samples += 1 + current_position_ids = torch.cat( + (current_position_ids, torch.arange(start=0, end=num_incoming_tokens)), dim=0 + ) + current_batch = torch.cat((current_batch, sorted_ids[i]), dim=0) + sorted_labels[i][0] = IGNORE_INDEX + current_label_batch = torch.cat((current_label_batch, sorted_labels[i]), dim=0) + seqlens_in_batch.append(num_incoming_tokens) + current_batch_images.extend(sorted_images[i]) + i += 1 + assert current_num_images == len(current_batch_images) + else: + break + + # Padding the sample with the dummy image sample, if there are no enough images + MAX_RETRY = self.sp_degree + num_retry = 0 + while current_num_images < self.sp_degree and current_len < max_seq_length and num_retry <= MAX_RETRY: + current_num_images += dummy_image.size(0) + current_len += dummy_seqlen + current_num_samples += 1 + current_position_ids = torch.cat((current_position_ids, dummy_position_ids), dim=0) + current_batch = torch.cat((current_batch, dummy_input_ids), dim=0) + current_label_batch = torch.cat((current_label_batch, dummy_labels), dim=0) + seqlens_in_batch.append(dummy_seqlen) + current_batch_images.extend(dummy_image) + # We pad from left side to ensure correct grad flow + # current_batch = torch.cat((dummy_input_ids, current_batch), dim=0) + # current_label_batch = torch.cat((dummy_labels, current_label_batch), dim=0) + # seqlens_in_batch.insert(0, dummy_seqlen) + # current_batch_images = torch.cat((dummy_image, current_batch_images), dim=0) + num_retry += 1 + + # Drop the samples that do not have enough images + if current_num_images < self.sp_degree: + print(f"Warning: Skipping one packed sample with {current_num_images} images") + seqlens_in_batch = seqlens_in_batch[:-current_num_samples] + continue + + max_sample_len = max(max_sample_len, current_len) + batches.append(current_batch) + label_batches.append(current_label_batch) + position_ids.append(current_position_ids) + batch_images.append(current_batch_images) + + try: + assert current_num_images == len(torch.where(current_batch == image_token_id)[0].tolist()) + except AssertionError: + print(f"Error num_images on {self.sp_rank}", current_num_images) + print("current_batch", current_batch) + print( + f"Error len(torch.where(batches[i] == image_token_id)[0].tolist() on {self.sp_rank}:", + len(torch.where(current_batch == image_token_id)[0].tolist()), + ) + print(f"Error len(current_batch_images) on {self.sp_rank}:", len(current_batch_images)) + raise AssertionError + + # Split for sequence parallelism + for i in range(len(batches)): + image_token_indices = torch.where(batches[i] == image_token_id)[0].tolist() + image_ids = torch.arange(0, len(image_token_indices), dtype=torch.int32) + batches[i] = extract_local_input_ids( + batches[i], image_token_indices, self.sp_rank, self.sp_degree, self.tokenizer.bos_token_id + ) + label_batches[i] = extract_local_input_ids( + label_batches[i], image_token_indices, self.sp_rank, self.sp_degree, self.tokenizer.bos_token_id + ) + batch_images[i] = torch.concat( + extract_local_from_list(batch_images[i], self.sp_rank, self.sp_degree), dim=0 + ) + H, W = batch_images[i].size(-2), batch_images[i].size(-1) + 
batch_images[i] = batch_images[i].reshape(-1, 3, W, H) + num_images = len(batch_images[i]) + + try: + assert num_images == len(torch.where(batches[i] == image_token_id)[0].tolist()) + except AssertionError: + print(f"Error num_images on {self.sp_rank}", num_images) + print("batches[i]", batches[i]) + print( + f"Error len(torch.where(batches[i] == image_token_id)[0].tolist() on {self.sp_rank}:", + len(torch.where(batches[i] == image_token_id)[0].tolist()), + ) + print(f"Error batch_images[i] on {self.sp_rank}:", batch_images[i].shape) + raise AssertionError + position_ids[i] = extract_local_position_ids( + position_ids[i], image_token_indices, image_ids, self.sp_rank, self.sp_degree, NUM_TOKENS_PER_IMAGE - 1 + ) + + input_ids = torch.nn.utils.rnn.pad_sequence( + batches, batch_first=True, padding_value=self.tokenizer.pad_token_id + ) + labels = torch.nn.utils.rnn.pad_sequence(label_batches, batch_first=True, padding_value=IGNORE_INDEX) + seqlens_in_batch = [torch.tensor(x) for x in seqlens_in_batch] + seqlens_in_batch = torch.stack(seqlens_in_batch, axis=0) + seqlens_in_batch = seqlens_in_batch.flatten() + position_ids = torch.nn.utils.rnn.pad_sequence(position_ids, batch_first=True, padding_value=-1) + + if batch_images: + batch_images = [torch.unbind(images) for images in batch_images] + flat_batch_images = [item for sublist in batch_images for item in sublist] + else: + flat_batch_images = None + batch = dict( + input_ids=input_ids, + labels=labels, + # notice that we inject attention mask here + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + seqlens_in_batch=seqlens_in_batch, + media={"image": flat_batch_images}, + media_config={"image": {}}, + position_ids=position_ids, + ) + return batch + + +def make_supervised_data_module( + tokenizer: PreTrainedTokenizer, + data_args: DataArguments, + training_args: TrainingArguments, +) -> Dict: + """Make dataset and collator for supervised fine-tuning. + This function is originally implemented by the LLaVA team and + modified by Jason Lu, Haotian Tang and Ligeng Zhu.""" + datasets_mixture.register_datasets_mixtures() + + from .builder import build_dataset + + train_dataset = build_dataset(data_args.data_mixture, data_args, training_args, tokenizer) + training_args.sample_lens = [len(d) for d in train_dataset.datasets] + + PROCESS_GROUP_MANAGER = get_pg_manager() + if PROCESS_GROUP_MANAGER is None: + data_collator = DataCollator(tokenizer=tokenizer) + else: + sp_degree = training_args.seq_parallel_size + sp_rank = PROCESS_GROUP_MANAGER.sp_rank + ring_degree = PROCESS_GROUP_MANAGER.ring_degree + ring_type = PROCESS_GROUP_MANAGER.ring_type + data_collator = DataCollatorForSupervisedDatasetSeqParallel( + tokenizer=tokenizer, + data_args=data_args, + training_args=training_args, + sp_degree=sp_degree, + sp_rank=sp_rank, + ring_degree=ring_degree, + ring_type=ring_type, + ) + + return dict( + train_dataset=train_dataset, + data_collator=data_collator, + ) diff --git a/llava/data/datasets_mixture.py b/llava/data/datasets_mixture.py new file mode 100644 index 0000000000000000000000000000000000000000..62a53fcf5fe69256e5cd4b005ab7f28f63d9fe7c --- /dev/null +++ b/llava/data/datasets_mixture.py @@ -0,0 +1,80 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
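+#
+# This module keeps a small registry (DATASETS_LEGACY) of named dataset mixtures;
+# register_datasets_mixtures() below adds entries, and the names are later resolved
+# through llava/data/builder.py when data_args.data_mixture is parsed.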
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import warnings
+from dataclasses import dataclass, field
+
+
+@dataclass
+class Dataset:
+    dataset_name: str
+    dataset_type: str = field(default="torch")
+    data_path: str = field(default=None, metadata={"help": "Path to the training data."})
+    meta_path: str = field(default=None, metadata={"help": "Path to the meta data for webdataset."})
+    image_path: str = field(default=None, metadata={"help": "Path to the training image data."})
+    speech_path: str = field(default=None, metadata={"help": "Path to the training speech data."})
+    caption_choice: str = field(default=None, metadata={"help": "Path to the caption directory for recaption."})
+    description: str = field(
+        default=None,
+        metadata={
+            "help": "Detailed description of where the data is from, how it is labelled, intended use case and the size of the dataset."
+        },
+    )
+    test_script: str = None
+    maintainer: str = None
+    ############## ############## ############## ############## ############## ##############
+    caption_choice: str = field(default=None, metadata={"help": "Path to the captions for webdataset."})
+    caption_choice_2: str = field(default=None, metadata={"help": "Path to the captions for webdataset."})
+    start_idx: float = field(default=-1, metadata={"help": "Start index of the dataset."})
+    end_idx: float = field(default=-1, metadata={"help": "End index of the dataset."})
+
+
+DATASETS_LEGACY = {}
+
+
+def add_dataset(dataset):
+    if dataset.dataset_name in DATASETS_LEGACY:
+        # make sure the dataset_name is unique
+        warnings.warn(f"{dataset.dataset_name} already exists in DATASETS. Make sure the name is unique.")
+    assert "+" not in dataset.dataset_name, "Dataset name cannot include symbol '+'."
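+    # A hypothetical registration, mirroring the entries in register_datasets_mixtures()
+    # below (dataset name and path are placeholders, not shipped datasets):
+    #   add_dataset(Dataset(dataset_name="my_audio_qa", data_path="/path/to/my_audio_qa/train.json"))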
+ DATASETS_LEGACY.update({dataset.dataset_name: dataset}) + + +def register_datasets_mixtures(): + ############## ############## ############## ############## ############## ############## + # Audio Datasets + ############## ############## ############## ############## ############## ############## + + data_mixture_1 = Dataset( + dataset_name="data_mixture_1", + dataset_type="torch", + data_path="/path/to/your/data_mixture_1/train.json", + ) + add_dataset(data_mixture_1) + + data_mixture_2 = Dataset( + dataset_name="data_mixture_2", + dataset_type="torch", + data_path="/path/to/your/data_mixture_2/train.json", + ) + add_dataset(data_mixture_2) + # Add more data mixtures below \ No newline at end of file diff --git a/llava/data/registry/datasets/audio_test.yaml b/llava/data/registry/datasets/audio_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..144c86b4bee8d3bbe67b32b7e7948e6f3c8e4bb9 --- /dev/null +++ b/llava/data/registry/datasets/audio_test.yaml @@ -0,0 +1,97 @@ +--- +Clotho-AQA-AQA: + _target_: llava.data.LLaVADataset + data_path: Clotho-AQA-AQA/test.json +Music-AVQA-AQA_All: + _target_: llava.data.LLaVADataset + data_path: Music-AVQA-AQA_All/test.json +CochlScene-SceneClassification: + _target_: llava.data.LLaVADataset + data_path: CochlScene-SceneClassification/test.json +NSynth-Source: + _target_: llava.data.LLaVADataset + data_path: NSynth-Source/test.json +NSynth-Instrument: + _target_: llava.data.LLaVADataset + data_path: NSynth-Instrument/test.json +FSD50k-EventClassification: + _target_: llava.data.LLaVADataset + data_path: FSD50k-EventClassification/test.json +Clotho-v2-AudioCaptioning: + _target_: llava.data.LLaVADataset + data_path: Clotho-v2-AudioCaptioning/test.json +audiocaps-AudioCaptioning: + _target_: llava.data.LLaVADataset + data_path: audiocaps-AudioCaptioning/test.json +ravdess-EmotionClassification: + _target_: llava.data.LLaVADataset + data_path: ravdess-EmotionClassification/val.json +GTZAN-GenreClassification: + _target_: llava.data.LLaVADataset + data_path: GTZAN-GenreClassification/test.json +UrbanSound8K-EventClassification: + _target_: llava.data.LLaVADataset + data_path: UrbanSound8K-EventClassification/train.json +Medley-solos-DB-InstrClassification: + _target_: llava.data.LLaVADataset + data_path: Medley-solos-DB-InstrClassification/test.json +ESC50-EventClassification: + _target_: llava.data.LLaVADataset + data_path: ESC50-EventClassification/train.json +CREMA-D-EmotionClassification: + _target_: llava.data.LLaVADataset + data_path: CREMA-D-EmotionClassification/test.json +IEMOCAP-EmotionClassification: + _target_: llava.data.LLaVADataset + data_path: IEMOCAP-EmotionClassification/test.json +MELD-EmotionClassification: + _target_: llava.data.LLaVADataset + data_path: MELD-EmotionClassification/test.json +MELD-SentimentClassification: + _target_: llava.data.LLaVADataset + data_path: MELD-SentimentClassification/test.json +MMAU: + _target_: llava.data.LLaVADataset + data_path: MMAU/test.json +MMAU-mini: + _target_: llava.data.LLaVADataset + data_path: MMAU/test-mini.json +AudioEntailmentQA: + _target_: llava.data.LLaVADataset + data_path: AudioEntailmentQA/test.json +SPGI-ASR: + _target_: llava.data.LLaVADataset + data_path: SPGI-ASR/val.json +SWBD-ASR: + _target_: llava.data.LLaVADataset + data_path: SWBD-ASR/val.json +LibriSpeech-ASR-clean: + _target_: llava.data.LLaVADataset + data_path: LibriSpeech-ASR/test_clean.json +LibriSpeech-ASR-other: + _target_: llava.data.LLaVADataset + data_path: LibriSpeech-ASR/test_other.json 
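+# The remaining entries follow the same pattern: `_target_` instantiates
+# llava.data.LLaVADataset and `data_path` is a JSON annotation file resolved by the
+# dataset builder (assumption: relative to a configured data root, not shown here).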
+VoxPopuli-ASR: + _target_: llava.data.LLaVADataset + data_path: VoxPopuli-ASR/test.json +Europarl-ASR: + _target_: llava.data.LLaVADataset + data_path: Europarl-ASR/test.json +CV-ASR: + _target_: llava.data.LLaVADataset + data_path: CV-ASR/test.json +GigaSpeech-ASR: + _target_: llava.data.LLaVADataset + data_path: GigaSpeech-ASR/test.json +CompA-R-AQA: + _target_: llava.data.LLaVADataset + data_path: CompA-R-AQA/test.json +MuschoMusicQA: + _target_: llava.data.LLaVADataset + data_path: MuschoMusicQA/test.json +CMM: + _target_: llava.data.LLaVADataset + data_path: CMM/test.json +AIR-Bench: + _target_: llava.data.LLaVADataset + data_path: AIR-Bench/test.json diff --git a/llava/data/registry/datasets/default.yaml b/llava/data/registry/datasets/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1dc18dfebe7b092e029ba6661c41925d503477ae --- /dev/null +++ b/llava/data/registry/datasets/default.yaml @@ -0,0 +1,5 @@ +--- +dummy: + _target_: llava.data.DummyDataset + num_instances: 10000 + comments: dummy dataset for testing diff --git a/llava/data/registry/mixtures.yaml b/llava/data/registry/mixtures.yaml new file mode 100644 index 0000000000000000000000000000000000000000..70f69accb593595bd978de144a8b571b6c24b3b6 --- /dev/null +++ b/llava/data/registry/mixtures.yaml @@ -0,0 +1,78 @@ +--- +audio_speech_all: + -CV-ASR_1 + -MELD-EmotionClassification+ + -BBCSoundEffects-AudioDescription + -SWBD-ASR_1 + -WavCaps-SoundBible-AudioCaptioning + -AudioSet-Speech-Audio-QA + -SONYC-UST-EventClassification + -VoxPopuli-ASR_1 + -FSD50k-EventClassification + -SalmonnQA + -emov-db-EmotionClassification + -LLARK_MagnaTagATune-mir+tess-EmotionClassification + -Europarl-ASR_1 + -jl-corpus-EmotionClassification + -Ego-10-AudioCaptioning + -SPGI-ASR_1 + -CREMA-D-EmotionClassification + -MusicBenchQA + -WavCaps-BBC_Sound_Effects-AudioCaptioning + -NSynth-Instrument + -SpokenSquadQA + -NSynth-MIR + -AudioEntailmentQA + -GigaSpeech-ASR_1 + -WavCaps-AudioSet_SL-AudioCaptioning + -NonSpeech7k-EventClassification + -chime-home-EventClassification + -MusicCaps-AudioCaptioning + -LP-MusicCaps-MSD-AudioCaptioning + -Ego-30-AudioCaptioning + -NSynth-Source+Clotho-v2-AudioCaptioning + -LP-MusicCaps-MC-AudioCaptioning + -Clotho-AQA-EventClassification + -WavCaps-FreeSound-AudioCaptioning + -LLARK_MagnaTagATune-reasoning + -AudioSet-Temporal-Speech-Audio-QA + -TUT-EventClassification + -ESC50-EventClassification + -WavText5K-Tagging + -MELD-SentimentClassification + -Music-AVQA-AQA_All + -Music-AVQA-AVQA_All + -MACS-AudioCaptioning + -Medley-solos-DB-InstrClassification + -AudioSet-EventClassification + -OMGEmotion-EmotionClassification + -FMA-GenreClassification + -Epidemic_sound-AudioCaptioning + -CochlScene-SceneClassification + -LLARK_FMA-reasoning + -ravdess-EmotionClassification + -CompA-R-AQA + -MU-LLAMA-AQA + -musdbhq-InstrClassification + -UrbanSound8K-EventClassification + -audiocaps-AudioCaptioning + -VocalSound-VocalClassification + -CLAP_freesound-AudioCaptioning + -MMAUQA + -SongDescriber-AudioCaptioning + -HeySQuADQA + -Mira-AudioCaptioning + -Clotho-AQA-AQA + -LibriSpeech-ASR_1 + -IEMOCAP-EmotionClassification + -AudioSetFullwoAudioMusicCaps-EventClassification + -MSP-PODCAST-Publish-1.9-EmotionClassification + -OpenAQA-AQA + -SoundDescs-AudioDescription + -LibriSQA + -LLARK_FMA-mir + -LP-MusicCaps-MTT-AudioCaptioning + -GTZAN-GenreClassification + -musdbhq-captioning + -YesNoQA + diff --git a/llava/entry.py b/llava/entry.py new file mode 100644 index 
0000000000000000000000000000000000000000..44348e95eea2598a030f6953ff2098e133bed68f --- /dev/null +++ b/llava/entry.py @@ -0,0 +1,60 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import os +import typing +from typing import List, Optional + +if typing.TYPE_CHECKING: + from transformers import PreTrainedModel +else: + PreTrainedModel = None + +__all__ = ["load"] + + +def load( + model_path: str, + model_base: Optional[str] = None, + devices: Optional[List[int]] = None, + **kwargs, +) -> PreTrainedModel: + import torch + + from llava.conversation import auto_set_conversation_mode + from llava.mm_utils import get_model_name_from_path + from llava.model.builder import load_pretrained_model + + auto_set_conversation_mode(model_path) + + model_name = get_model_name_from_path(model_path) + model_path = os.path.expanduser(model_path) + if os.path.exists(os.path.join(model_path, "model")): + model_path = os.path.join(model_path, "model") + + # Set `max_memory` to constrain which GPUs to use + if devices is not None: + assert "max_memory" not in kwargs, "`max_memory` should not be set when `devices` is set" + kwargs.update(max_memory={device: torch.cuda.get_device_properties(device).total_memory for device in devices}) + + model = load_pretrained_model(model_path, model_name, model_base, **kwargs)[1] + return model diff --git a/llava/eval/__init__.py b/llava/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d807fb5a2b32ea471a6bfec0d57694f57f91e377 --- /dev/null +++ b/llava/eval/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +import os + +from llava.utils import io + +__all__ = ["EVAL_ROOT", "TASKS"] + + +EVAL_ROOT = "scripts/eval" +TASKS = io.load(os.path.join(os.path.dirname(__file__), "registry_audio.yaml")) diff --git a/llava/eval/eval_audio_bench.py b/llava/eval/eval_audio_bench.py new file mode 100644 index 0000000000000000000000000000000000000000..b3017aef3ba10519f507e4d878d834c7ab3d7dbf --- /dev/null +++ b/llava/eval/eval_audio_bench.py @@ -0,0 +1,117 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
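+#
+# Example invocation (hypothetical paths; flags correspond to the argparse options
+# defined below, and the distributed launcher is only a suggestion):
+#   torchrun --nproc_per_node=8 llava/eval/eval_audio_bench.py \
+#       --model-base /path/to/audio-flamingo-3 --task MMAU --output-dir runs/af3_eval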
+ +import argparse +import csv +import itertools +import json +import os + +import torch +from datasets import load_dataset +from tqdm import tqdm + +import llava +from llava import conversation as conversation_lib +from llava.data.builder import DATASETS +from llava.eval.mmmu_utils.eval_utils import parse_choice +from llava.utils import distributed as dist +from llava.utils import io +from llava.utils.logging import logger + + +def load_existing_ids(output_file): + if not os.path.exists(output_file): + return set(), [] + try: + with open(output_file, "r") as f: + lines = f.readlines() + outputs = [json.loads(line) for line in lines] + processed_ids = {item["id"] for item in outputs} + return processed_ids, outputs + except Exception as e: + print(f"Error loading existing outputs: {e}") + return set(), [] + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default=None) + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--task", type=str, default=None) + parser.add_argument("--conv-mode", type=str, default="auto") + parser.add_argument("--generation-config", type=json.loads) + parser.add_argument("--output-dir", type=str, default=None) + args = parser.parse_args() + + # Set up distributed environment + dist.init() + devices = range(dist.local_rank(), torch.cuda.device_count(), dist.local_size()) + torch.cuda.set_device(devices[0]) + + # Load stage 3 model with line 56 + model = llava.load(args.model_base, model_base=None, devices=devices) + # Uncomment line 58-63 to load stage 3.5 model on top of stage 3 for thinking mode and long audio mode + # model = PeftModel.from_pretrained( + # model, + # args.model_path, + # device_map="auto", + # torch_dtype=torch.float16, + # ) + # Set up generation config + generation_config = model.default_generation_config + if args.generation_config is not None: + generation_config.update(**args.generation_config) + + # Load data and chunk it + json_file = DATASETS[args.task]["data_path"] + instances = io.load(json_file) + instances = instances[dist.rank() :: dist.size()] + + output_path = os.path.join(args.output_dir, f"outputs_{args.task}.jsonl") + processed_ids, outputs = load_existing_ids(output_path) + + count = len(outputs) + # Run inference + new_outputs = [] + for instance in tqdm(instances, disable=not dist.is_main()): + uuid = instance["id"] + sound_path = instance["sound"] + + if sound_path in processed_ids: + continue # Skip if already processed + sound = llava.Sound(sound_path) + conversations = instance["conversations"] + question = conversations[0]["value"] + + response = model.generate_content([sound, question], generation_config=generation_config) + + print("response", response) + + output = {"id": sound_path, "question": question, "gt_answer": conversations[1]["value"], "pred": response} + new_outputs.append(output) + count = count +1 + if count % 20 == 0: + # Gather and save outputs + if dist.size() > 1: + outputs_new = dist.gather(new_outputs, dst=0) + if dist.is_main(): + outputs_new = list(itertools.chain(*outputs_new)) + final_outputs = outputs + outputs_new + io.save(os.path.join(args.output_dir, f"outputs_{args.task}.jsonl"), final_outputs) + else: + final_outputs = outputs + new_outputs + io.save(os.path.join(args.output_dir, f"outputs_{args.task}.jsonl"), final_outputs) + if dist.size() > 1: + new_outputs = dist.gather(new_outputs, dst=0) + if not dist.is_main(): + return + new_outputs = list(itertools.chain(*new_outputs)) + final_outputs = 
outputs + new_outputs + io.save(os.path.join(args.output_dir, "outputs_"+str(args.task)+".jsonl"), final_outputs) + +if __name__ == "__main__": + main() diff --git a/llava/eval/mmmu_utils/__pycache__/eval_utils.cpython-311.pyc b/llava/eval/mmmu_utils/__pycache__/eval_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62c020ee5b25a39cc6679b0f7f28235d1db3120c Binary files /dev/null and b/llava/eval/mmmu_utils/__pycache__/eval_utils.cpython-311.pyc differ diff --git a/llava/eval/mmmu_utils/eval_utils.py b/llava/eval/mmmu_utils/eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..42c3cc6bc917ffa3f5af7ed7c2ec601ef8dbfbc1 --- /dev/null +++ b/llava/eval/mmmu_utils/eval_utils.py @@ -0,0 +1,61 @@ +# This file is originated from the official MMMU codebase: +# https://github.com/MMMU-Benchmark/MMMU + +import random + +import numpy as np + + +def parse_choice(response, all_choices, index2ans=None): + """ + Parse the prediction from the generated response. + Return the predicted index e.g., A, B, C, D. + """ + for char in [",", ".", "!", "?", ";", ":", "'"]: + response = response.strip(char) + response = " " + response + " " # add space to avoid partial match + + index_ans = True + ans_with_brack = False + candidates = [] + for choice in all_choices: # e.g., (A) (B) (C) (D) + if f"({choice})" in response: + candidates.append(choice) + ans_with_brack = True + + if len(candidates) == 0: + for choice in all_choices: # e.g., A B C D + if f" {choice} " in response: + candidates.append(choice) + + # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example + if len(candidates) == 0 and len(response.split()) > 5 and index2ans is not None: + for index, ans in index2ans.items(): + if ans.lower() in response.lower(): + candidates.append(index) + index_ans = False # it's content ans. + + if len(candidates) == 0: # still not get answer, randomly choose one. + pred_index = random.choice(all_choices) + elif len(candidates) > 1: + start_indexes = [] + if index_ans: + if ans_with_brack: + for can in candidates: + index = response.rfind(f"({can})") + start_indexes.append(index) # -1 will be ignored anyway + # start_indexes = [generated_response.index(f'({can})') for can in candidates] + else: + for can in candidates: + index = response.rfind(f" {can} ") + start_indexes.append(index) + else: + for can in candidates: + index = response.lower().rfind(index2ans[can].lower()) + start_indexes.append(index) + # get the last one + pred_index = candidates[np.argmax(start_indexes)] + else: # if only one candidate, use it. 
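+        # e.g. parse_choice("The answer is (B).", ["A", "B", "C", "D"]) -> "B"
+        # ends up here, since "(B)" is the only bracketed choice found in the response.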
+ pred_index = candidates[0] + + return pred_index diff --git a/llava/eval/registry_audio.yaml b/llava/eval/registry_audio.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8de55c73c3d12785044fb0cbd09c847eb548a514 --- /dev/null +++ b/llava/eval/registry_audio.yaml @@ -0,0 +1,93 @@ +Clotho-AQA-AQA: + tags: + - local +Music-AVQA-AQA_All: + tags: + - local +CochlScene-SceneClassification: + tags: + - local +NSynth-Source: + tags: + - local +NSynth-Instrument: + tags: + - local +FSD50k-EventClassification: + tags: + - local +Clotho-v2-AudioCaptioning: + tags: + - local +audiocaps-AudioCaptioning: + tags: + - local +ravdess-EmotionClassification: + tags: + - local +GTZAN-GenreClassification: + tags: + - local +UrbanSound8K-EventClassification: + tags: + - local +Medley-solos-DB-InstrClassification: + tags: + - local +ESC50-EventClassification: + tags: + - local +CREMA-D-EmotionClassification: + tags: + - local +IEMOCAP-EmotionClassification: + tags: + - local +MELD-EmotionClassification: + tags: + - local +MELD-SentimentClassification: + tags: + - local +MMAU: + tags: + - local +AudioEntailmentQA: + tags: + - local +SPGI-ASR: + tags: + - local +SWBD-ASR: + tags: + - local +LibriSpeech-ASR-clean: + tags: + - local +LibriSpeech-ASR-other: + tags: + - local +VoxPopuli-ASR: + tags: + - local +Europarl-ASR: + tags: + - local +CV-ASR: + tags: + - local +GigaSpeech-ASR: + tags: + - local +CompA-R-AQA: + tags: + - local +MuschoMusicQA: + tags: + - local +CMM: + tags: + - local +AIR-Bench: + tags: + - local diff --git a/llava/media.py b/llava/media.py new file mode 100644 index 0000000000000000000000000000000000000000..8bcc0aa44034212a4649b9ef256eb5d995e48380 --- /dev/null +++ b/llava/media.py @@ -0,0 +1,47 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +__all__ = ["Media", "File", "Image", "Video", "Speech", "Sound"] + + +class Media: + pass + + +class File(Media): + def __init__(self, path: str) -> None: + self.path = path + + +class Image(File): + pass + + +class Video(File): + pass + + +class Speech(File): + pass + +class Sound(File): + pass \ No newline at end of file diff --git a/llava/mm_utils.py b/llava/mm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1f4a23d338f9a4440aeaa6f44f7f5fb0d89093fe --- /dev/null +++ b/llava/mm_utils.py @@ -0,0 +1,641 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +# dynamic_preprocess and find_closest_aspect_ratio are referenced from https://github.com/OpenGVLab/InternVL + +import base64 +import os +import tempfile +from io import BytesIO + +import numpy as np +import torch +from PIL import Image +from transformers import StoppingCriteria + +from pydub import AudioSegment +from torchvision import transforms +import soundfile as sf +from librosa import resample as librosa_resample +import whisper +import random +from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, UniformClipSampler +DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds +def get_frame_from_vcap(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None): + import cv2 + + if fps == None or frame_count == None: + # if one of fps or frame_count is None, still recompute + fps = vidcap.get(cv2.CAP_PROP_FPS) + frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) + if fps == 0 or frame_count == 0: + print(f"Video file not found. return empty images. {video_file_name}") + return [ + Image.new("RGB", (720, 720)), + ] * num_frames, 0, [0.] + + duration = frame_count / fps + frame_interval = frame_count // num_frames + if frame_interval == 0 and frame_count <= 1: + print(f"frame_interval is equal to 0. return empty image. {video_file_name}") + return [ + Image.new("RGB", (720, 720)), + ] * num_frames, 0, [0.] + # print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval) + + images = [] + count = 0 + success = True + frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int) + frame_times = [frame / fps for frame in frame_indices] + while success: + # print("frame_count:", frame_count, "count:", count, "num_frames:", num_frames, "frame_interval:", frame_interval) + if frame_count >= num_frames: + success, frame = vidcap.read() + if count in frame_indices: + try: + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + im_pil = Image.fromarray(img) + images.append(im_pil) + except BaseException: + continue + if len(images) >= num_frames: + return images, num_frames, frame_times + count += 1 + else: + # Left padding frames if the video is not long enough + success, frame = vidcap.read() + if success: + try: + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + im_pil = Image.fromarray(img) + images.append(im_pil) + except BaseException: + continue + count += 1 + else: + break + if len(images) == 0: + raise ValueError("Did not find enough frames in the video. return empty image.") + + return images, len(images), frame_times + + +def get_frame_from_vcap_with_fps(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None): + """ + num_frames is the max number of frames the model can support. + frame_count is the number of frames in the input video. + max_fps is the max FPS of the model can support. + fps is the fps of the input video. 
+ """ + + import random + + import cv2 + + if fps == None or frame_count == None: + # if one of fps or frame_count is None, still recompute + fps = vidcap.get(cv2.CAP_PROP_FPS) + frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) + + if fps == 0 or frame_count == 0: + print(f"Video file not found. return empty images. {video_file_name}") + empty_video_frames = int(random.uniform(2, 8 * max_fps)) + return [ + Image.new("RGB", (720, 720)), + ] * empty_video_frames, 0, [0.] + + duration = frame_count / fps + # print("duration:", duration, "frames:", frame_count, "fps:", fps, "num_frames:", num_frames, "max_fps:", max_fps) + # If the video is too long (longer than max_fps and num_frames can support), + # we will use lower fps to sample frames. + if duration >= num_frames / max_fps: + frame_interval = frame_count // num_frames + + # If the video is too short, we will skip the video if there is only one frame. + if frame_interval == 0 and frame_count <= 1: + print(f"frame_interval is equal to 0. return empty image. {video_file_name}") + empty_video_frames = int(random.uniform(2, 8 * max_fps)) + return [ + Image.new("RGB", (720, 720)), + ] * empty_video_frames, 0, [0.] + + images = [] + count = 0 + success = True + frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int) + frame_times = [frame / fps for frame in frame_indices] + while success: + if frame_count >= num_frames: + # success, frame = vidcap.read() + if count in frame_indices: + success, frame = vidcap.read() + try: + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + im_pil = Image.fromarray(img) + images.append(im_pil) + except: + # print("Failed to read frame:", count) + continue + if len(images) >= num_frames: + return images, num_frames, frame_times + else: + success = vidcap.grab() + count += 1 + else: + # Left padding frames if the video is not long enough + success, frame = vidcap.read() + if success: + try: + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + im_pil = Image.fromarray(img) + images.append(im_pil) + except: + # print("Failed to read frame:", count) + continue + count += 1 + else: + break + else: + frames_required = int(duration * max_fps) + frame_indices = np.linspace(0, frame_count - 1, frames_required, dtype=int) + if frames_required == 0: + print(f"frames_required is fewer than 2. Duration {duration}, return empty image.") + empty_video_frames = int(random.uniform(2, 8 * max_fps)) + return [ + Image.new("RGB", (720, 720)), + ] * empty_video_frames, 0, [0.] + elif frames_required == 1: + frame_indices = np.linspace(0, frame_count - 1, 2, dtype=int) + images = [] + count = 0 + looked = 0 + success = True + + while success: + success, frame = vidcap.read() + if success and (looked in frame_indices): + try: + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + im_pil = Image.fromarray(img) + images.append(im_pil) + except: + continue + count += 1 + looked += 1 + frame_times = [frame / fps for frame in frame_indices] + + if len(images) == 0: + empty_video_frames = int(random.uniform(2, 8 * max_fps)) + return [ + Image.new("RGB", (720, 720)), + ] * empty_video_frames, 0, [0.] + else: + return images, len(images), frame_times + + +def opencv_extract_frames(vpath_or_bytesio, frames=6, max_fps=0.0, fps=None, frame_count=None): + """ + Extract frames from a video using OpenCV. + + Args: + vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video. + frames (int): Number of frames to extract from the video. + fps (float): Frames per second of the video. 
If 0.0, the function will extract frames at equal intervals. + + Returns: + list: List of PIL Images extracted from the video. + + Raises: + NotImplementedError: If the type of `vpath_or_bytesio` is not supported. + """ + import cv2 + if isinstance(vpath_or_bytesio, str): + vidcap = cv2.VideoCapture(vpath_or_bytesio) + if max_fps > 0.0: + return get_frame_from_vcap_with_fps( + vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio + ) + return get_frame_from_vcap( + vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio + ) + elif isinstance(vpath_or_bytesio, (BytesIO,)): + # assuming mp4 + with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video: + temp_video.write(vpath_or_bytesio.read()) + temp_video_name = temp_video.name + vidcap = cv2.VideoCapture(temp_video_name) + if max_fps > 0.0: + return get_frame_from_vcap_with_fps( + vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name + ) + return get_frame_from_vcap( + vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name + ) + else: + raise NotImplementedError(type(vpath_or_bytesio)) + + +def load_image_from_base64(image): + return Image.open(BytesIO(base64.b64decode(image))) + + +def expand2square(pil_img, background_color): + """ + Expand the given PIL image to a square shape by adding padding. + + Parameters: + - pil_img: The PIL image to be expanded. + - background_color: The color of the padding to be added. + + Returns: + - The expanded PIL image. + + If the image is already square, it is returned as is. + If the image is wider than it is tall, padding is added to the top and bottom. + If the image is taller than it is wide, padding is added to the left and right. 
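+
+    Example: a 200x100 (width x height) image is padded to 200x200, with the
+    original pasted at vertical offset (200 - 100) // 2 = 50 and the new area
+    filled with background_color.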
+ """ + width, height = pil_img.size + if pil_img.mode == "L": + background_color = background_color[0] + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def dynamic_preprocess(image, min_num=1, max_num=12, image_size=384, use_thumbnail=True): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + } + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + +def dynamic_s2_preprocess(image, s2_scales=[384, 768, 1152], max_num=12, image_size=384): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + min_num = (s2_scales[-1] // s2_scales[0]) ** 2 # at least use number of tiles as the largest scale + + processed_images = [] + + ########################################################################################## + ############# Add tiles for all but the last scale using fixed squre ratio ############### + ########################################################################################## + + for scale in s2_scales[:-1]: + target_width = image_size * (scale // s2_scales[0]) + target_height = image_size * (scale // s2_scales[0]) + blocks = (scale // s2_scales[0]) ** 2 + + # resize the image + resized_img = image.resize((target_width, target_height)) + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // 
(target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + + ########################################################################################## + ################ Add tiles for the last scale using dynamic aspect ratio ################# + ########################################################################################## + + # calculate the existing image aspect ratio + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + } + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + + return processed_images, (target_aspect_ratio[1], target_aspect_ratio[0]) + + + +def dynamic_s2_process_images_and_prompt(images, data_args, image_folder=None): + idx = 0 + all_images = [] + all_block_size = [] + for img in images: + processed_images, block_size = process_image(img, data_args, image_folder, enable_dynamic_s2=True) + all_images.append(processed_images) + all_block_size.append(block_size) + idx += 2 + if all_images: + all_images = torch.cat(all_images) + else: + all_images = None + return all_images, all_block_size + + +def process_image( + image_file, data_args, image_folder, enable_dynamic_res=False, enable_dynamic_s2=False, max_tiles=None +): + processor = data_args.image_processor + if isinstance(image_file, str): + if image_folder is not None: + image = Image.open(os.path.join(image_folder, image_file)).convert("RGB") + else: + image = Image.open(image_file).convert("RGB") + else: + # image is stored in bytearray + image = image_file + image = image.convert("RGB") + if hasattr(data_args.image_processor, "crop_size"): + # CLIP vision tower + crop_size = data_args.image_processor.crop_size + else: + # SIGLIP vision tower + assert hasattr(data_args.image_processor, "size") + crop_size = data_args.image_processor.size + if "dynamic_s2" in data_args.image_aspect_ratio and enable_dynamic_s2: + assert crop_size["height"] == crop_size["width"] + images, block_size = dynamic_s2_preprocess( + image, s2_scales=data_args.s2_scales, max_num=data_args.max_tiles, image_size=crop_size["height"] + ) + images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images] + return torch.stack(images), block_size + if "dynamic" in data_args.image_aspect_ratio and enable_dynamic_res: + assert crop_size["height"] == crop_size["width"] + if max_tiles is not None: + max_num = max_tiles + else: + max_num = data_args.max_tiles + images = dynamic_preprocess(image, min_num=data_args.min_tiles, max_num=max_num, image_size=crop_size["height"]) + images = 
[processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images] + return torch.stack(images) + + if data_args.image_aspect_ratio == "resize": + image = image.resize((crop_size["width"], crop_size["height"])) + if data_args.image_aspect_ratio == "pad": + + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean)) + image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0] + else: + # Using default behavior of the vision encoder + # For CLIP, default is central crop + # For Radio, default is central crop + # For Siglip, default is resize + # For InternVIT, default is resize + image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0] + return image + +def get_num_windows(T, sr, max_num_window=5): + + window_length = int(30.0 * sr) + window_overlap = int(0.0 * sr) + max_num_window = max_num_window + + num_windows = 1 + if T <= window_length: + num_windows = 1 + full_length = window_length + elif T >= (max_num_window * window_length - (max_num_window - 1) * window_overlap): + num_windows = max_num_window + full_length = (max_num_window * window_length - (max_num_window - 1) * window_overlap) + else: + num_windows = 1 + int(np.ceil((T - window_length) / float(window_length - window_overlap))) + full_length = num_windows * window_length - (num_windows - 1) * window_overlap + + return num_windows, full_length + +def load_audio(file_path, target_sr=16000, duration=30.0, start=0.0): + if file_path.endswith('.mp3'): + audio = AudioSegment.from_file(file_path) + if len(audio) > (start + duration) * 1000: + audio = audio[start * 1000:(start + duration) * 1000] + + if audio.frame_rate != target_sr: + audio = audio.set_frame_rate(target_sr) + + if audio.channels > 1: + audio = audio.set_channels(1) + + data = np.array(audio.get_array_of_samples()) + if audio.sample_width == 2: + data = data.astype(np.float32) / np.iinfo(np.int16).max + elif audio.sample_width == 4: + data = data.astype(np.float32) / np.iinfo(np.int32).max + else: + raise ValueError("Unsupported bit depth: {}".format(audio.sample_width)) + + else: + with sf.SoundFile(file_path) as audio: + original_sr = audio.samplerate + channels = audio.channels + + max_frames = int((start + duration) * original_sr) + + audio.seek(int(start * original_sr)) + frames_to_read = min(max_frames, len(audio)) + data = audio.read(frames_to_read) + + if data.max() > 1 or data.min() < -1: + data = data / max(abs(data.max()), abs(data.min())) + + if original_sr != target_sr: + if channels == 1: + data = librosa_resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr) + else: + data = librosa_resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0] + else: + if channels != 1: + data = data.T[0] + + if data.min() >= 0: + data = 2 * data / abs(data.max()) - 1.0 + else: + data = data / max(abs(data.max()), abs(data.min())) + + assert len(data.shape) == 1, data.shape + return data + +def process_images(images, image_processor, model_cfg, enable_dynamic_res=False, max_tiles=None): + model_cfg.image_processor = image_processor + new_images = [ + 
process_image(image, model_cfg, None, enable_dynamic_res=enable_dynamic_res, max_tiles=max_tiles) + for image in images + ] + + if all(x.shape == new_images[0].shape for x in new_images): + if len(new_images[0].shape) == 4: + new_images = torch.cat(new_images, dim=0) + elif len(new_images[0].shape) == 3: + new_images = torch.stack(new_images, dim=0) + else: + raise ValueError(f"new_images rank does not equal to 4, rank: {len(new_images[0].shape)}") + else: + raise ValueError("The shape of images in new_images is different!") + return new_images + +def process_sounds(sounds): + sounds = torch.tensor(sounds) + return sounds + +def process_sound_masks(masks): + masks = torch.tensor(masks[0]) + return masks + +def tokenizer_image_token(prompt, tokenizer, return_tensors=None): + return tokenizer(prompt, return_tensors=return_tensors).input_ids[0] + +def is_gemma_tokenizer(tokenizer): + return "gemma" in tokenizer.__class__.__name__.lower() + + +def get_model_name_from_path(model_path): + model_path = model_path.strip("/") + model_paths = model_path.split("/") + if model_paths[-1].startswith("checkpoint-"): + return model_paths[-2] + "_" + model_paths[-1] + else: + return model_paths[-1] + +class KeywordsStoppingCriteria(StoppingCriteria): + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.keyword_ids = [] + self.max_keyword_len = 0 + for keyword in keywords: + cur_keyword_ids = tokenizer(keyword).input_ids + if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: + cur_keyword_ids = cur_keyword_ids[1:] + if len(cur_keyword_ids) > self.max_keyword_len: + self.max_keyword_len = len(cur_keyword_ids) + self.keyword_ids.append(torch.tensor(cur_keyword_ids)) + self.tokenizer = tokenizer + self.start_len = input_ids.shape[1] + + def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) + self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] + for keyword_id in self.keyword_ids: + if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all(): + return True + outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] + for keyword in self.keywords: + if keyword in outputs: + return True + return False + + def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + outputs = [] + for i in range(output_ids.shape[0]): + outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) + return all(outputs) diff --git a/llava/model/FloatPointQuantizeTorch.py b/llava/model/FloatPointQuantizeTorch.py new file mode 100644 index 0000000000000000000000000000000000000000..b01c6f8ad008cdcd01c51a85399a194032df05b1 --- /dev/null +++ b/llava/model/FloatPointQuantizeTorch.py @@ -0,0 +1,85 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
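+
+"""Pure-PyTorch fake-quantization helpers for simulated low-precision floating-point
+formats such as E4M3 / E5M2 (ExMy) and the dynamic-tree variants.
+
+Hedged usage sketch (the tensor values are illustrative only):
+
+    import torch
+    x = torch.randn(4, 4, dtype=torch.float32)
+    # round-to-nearest onto an E4M3-style grid (4 exponent bits, 3 mantissa bits)
+    y_rne = floatExMy_quantize_torch(x, e_bit=4, m_bit=3, stochastic=False)
+    # stochastic rounding: uniform(-0.5, 0.5) noise is added to the scaled mantissa
+    y_sr = floatExMy_quantize_torch(x, e_bit=4, m_bit=3, stochastic=True)
+
+The Triton implementation in FloatPointQuantizeTriton.py appears to be the preferred
+fast path; this module keeps an equivalent torch-only reference.
+"""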
+ +import math + +import torch + + +def floatExMy_quantize_torch(x, e_bit, m_bit, stochastic): + sign, x_abs = x.sign(), x.abs() + Elow, Ehigh, Mhigh = -(2 ** (e_bit - 1)) + 2, 2 ** (e_bit - 1), 2**m_bit + expo = torch.floor(torch.log2(x_abs)) + expo = torch.clamp(expo, min=Elow, max=Ehigh) + mant = x_abs / torch.exp2(expo) + + mant_int = torch.floor(mant) + mant_frac = mant - mant_int + mant_frac = mant_frac * Mhigh + if stochastic: + noise = mant_frac.new(mant_frac.shape).uniform_(-0.5, 0.5) + mant_frac.add_(noise) + mant_frac = torch.round(mant_frac) + + mant_q = mant_int + mant_frac / Mhigh + y = sign * (2**expo) * mant_q + y = y.to(x) + + return y + + +def floatExM0_quantize_torch(x, e_bit, stochastic): + sign, x_abs = x.sign(), x.abs() + Elow, Ehigh = -(2 ** (e_bit - 1)) + 1, 2 ** (e_bit - 1) + expo = torch.log2(x_abs) + if stochastic: + noise = expo.new(expo.shape).uniform_(-0.5, 0.5) + expo.add(noise) + log_bias = math.log2(4 / 3) - 1 / 2 + expo.add(torch.ones_like(expo) * log_bias) + expo = torch.clamp(expo, min=Elow - 1, max=Ehigh) + expo = torch.round(expo) + + y = sign * (2**expo) * (expo > Elow) # When underflow, set the value to 0 + y = y.to(x) + + return y + + +def Dynamic_quantize_torch(x, bit, stochastic): + if stochastic: + raise NotImplementedError("Dynamic Tree quantization does not support stochastic") + sign, x_abs = x.sign(), x.abs() + expo = torch.ceil(torch.log10(x_abs)) + expo = torch.clamp(expo, min=2 - bit) + mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1 + + mant_frac = mant * 2 ** (bit - 2 - expo.abs()) + mant_frac = torch.round(mant_frac) + mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1 + y = sign * (10**expo) * mant_frac / 10 + + zero_mask = y.abs() > 1.01 * 10 ** (1 - bit) + y = y * zero_mask + y = y.to(x) + return y + + +def ZeroDynamic_quantize_torch(x, bit, stochastic): + if stochastic: + raise NotImplementedError("Dynamic Tree quantization does not support stochastic") + sign, x_abs = x.sign(), x.abs() + expo = torch.ceil(torch.log10(x_abs)) + expo = torch.clamp(expo, min=2 - bit) + mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1 + + mant_frac = mant * 2 ** (bit - 2 - expo.abs()) + mant_frac = torch.round(mant_frac) + mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1 + y = sign * (10**expo) * mant_frac / 10 + + y = y.to(x) + return y diff --git a/llava/model/FloatPointQuantizeTriton.py b/llava/model/FloatPointQuantizeTriton.py new file mode 100644 index 0000000000000000000000000000000000000000..e619c3739ec968c92460b3d8d9efd2736fd07d44 --- /dev/null +++ b/llava/model/FloatPointQuantizeTriton.py @@ -0,0 +1,199 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
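+
+"""Triton kernels for simulated ExMy floating-point quantization.
+
+floatExMy_quantize_triton(x, e_bit, m_bit, stochastic) flattens x, quantizes it in
+BLOCK_SIZE chunks, and splits inputs larger than segment_size (1024**3 elements)
+into segments because Triton indexing breaks beyond roughly 2 * 1024**3 elements.
+Only torch.bfloat16 / torch.float32 inputs are supported; for stochastic rounding
+the uniform(-0.5, 0.5) noise is pre-sampled outside the kernel and passed in.
+
+Hedged usage sketch (shapes and values are illustrative):
+
+    import torch
+    x = torch.randn(1024, device="cuda", dtype=torch.bfloat16)
+    y = floatExMy_quantize_triton(x, e_bit=4, m_bit=3, stochastic=False)  # E4M3-style
+"""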
+ +import math +import struct + +import numpy as np +import torch +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +segment_size = 1024**3 + + +def floatExMy_quantize_triton(x, e_bit, m_bit, stochastic): + x_ori_shape = x.shape + x = x.view(-1) + + n_elements = x.numel() + + if n_elements <= segment_size: + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + y = torch.empty_like(x) + + if x.dtype in [torch.bfloat16, torch.float32]: + if stochastic: + noise = x.new(x.shape).uniform_(-0.5, 0.5) + _floatExMy_stochastic_quantize_kernel[grid](x, noise, y, n_elements, e_bit, m_bit) + else: + _floatExMy_quantize_kernel[grid](x, y, n_elements, e_bit, m_bit) + torch.cuda.synchronize() + else: + raise NotImplementedError(f"Other data format {x.dtype} for float quantization triton") + else: # Triton will break when x.numel > 2 * 1024 ** 3 + num_segments = n_elements // segment_size + 1 + split_size = [segment_size] * (num_segments - 1) + [n_elements - segment_size * (num_segments - 1)] + x_list = x.split(split_size) + y_list = [] + del x + + for x in x_list: + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + y = torch.empty_like(x) + + if x.dtype in [torch.bfloat16, torch.float32]: + if stochastic: + noise = x.new(x.shape).uniform_(-0.5, 0.5) + _floatExMy_stochastic_quantize_kernel[grid](x, noise, y, n_elements, e_bit, m_bit) + else: + _floatExMy_quantize_kernel[grid](x, y, n_elements, e_bit, m_bit) + torch.cuda.synchronize() + else: + raise NotImplementedError(f"Other data format {x.dtype} for float quantization triton") + + y_list.append(y) + y = torch.concat(y_list) + del y_list + + y = y.reshape(x_ori_shape) + return y + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4), + triton.Config( + { + "BLOCK_SIZE": 1024, + }, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE": 2048, + }, + num_warps=4, + ), + ], + key=["n_elements"], +) +@triton.jit +def _floatExMy_quantize_kernel( + x_ptr, + output_ptr, + n_elements, + e_bit, + m_bit, + BLOCK_SIZE: tl.constexpr, +): + if isinstance(e_bit, tl.constexpr): + ebit = e_bit.value + else: + ebit = e_bit + + if isinstance(m_bit, tl.constexpr): + mbit = m_bit.value + else: + mbit = m_bit + + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + + x = x.to(tl.float32) + sign = 1 - 2 * libdevice.signbit(x) + x_abs = tl.abs(x) + Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2 + Ehigh = tl.exp2((ebit - 1).to(tl.float32)) + Mhigh = tl.exp2(mbit.to(tl.float32)) + expo = tl.floor(tl.log2(x_abs)) + expo = tl.clamp(expo, min=Elow, max=Ehigh) + mant = x_abs / tl.exp2(expo) + + mant_int = tl.floor(mant) + mant_frac = mant - mant_int + mant_frac = mant_frac * Mhigh + # mant_frac = mant_frac + noise + mant_frac = libdevice.round(mant_frac) + + mant_q = mant_int + mant_frac / Mhigh + y = sign * tl.exp2(expo) * mant_q + y = y.to(x_ptr.dtype.element_ty) + + tl.store(output_ptr + offsets, y, mask=mask) + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4), + triton.Config( + { + "BLOCK_SIZE": 1024, + }, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE": 2048, + }, + num_warps=4, + ), + ], + key=["n_elements"], +) +@triton.jit +def _floatExMy_stochastic_quantize_kernel( + x_ptr, + noise_ptr, + output_ptr, + n_elements, + e_bit, + m_bit, + BLOCK_SIZE: tl.constexpr, +): 
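+    # Stochastic-rounding variant of _floatExMy_quantize_kernel above: the same
+    # sign/exponent/mantissa decomposition and exponent clamping, but the
+    # pre-sampled uniform(-0.5, 0.5) noise loaded from noise_ptr is added to the
+    # scaled mantissa fraction before rounding (stochastic rounding).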
+ if isinstance(e_bit, tl.constexpr): + ebit = e_bit.value + else: + ebit = e_bit + + if isinstance(m_bit, tl.constexpr): + mbit = m_bit.value + else: + mbit = m_bit + + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + noise = tl.load(noise_ptr + offsets, mask=mask) + + x = x.to(tl.float32) + sign = 1 - 2 * libdevice.signbit(x) + x_abs = tl.abs(x) + Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2 + Ehigh = tl.exp2((ebit - 1).to(tl.float32)) + Mhigh = tl.exp2(mbit.to(tl.float32)) + expo = tl.floor(tl.log2(x_abs)) + expo = tl.clamp(expo, min=Elow, max=Ehigh) + mant = x_abs / tl.exp2(expo) + + mant_int = tl.floor(mant) + mant_frac = mant - mant_int + mant_frac = mant_frac * Mhigh + mant_frac = mant_frac + noise + mant_frac = libdevice.round(mant_frac) + + mant_q = mant_int + mant_frac / Mhigh + y = sign * tl.exp2(expo) * mant_q + y = y.to(x_ptr.dtype.element_ty) + + tl.store(output_ptr + offsets, y, mask=mask) diff --git a/llava/model/__init__.py b/llava/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e5494e136efbc55e4813a40184308c26fe5949b --- /dev/null +++ b/llava/model/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +from .language_model.llava_llama import LlavaLlamaConfig, LlavaLlamaModel + +# FP8 related comments, development in progress (PI: ligeng zhu, haochen xi) +# NOTE: VLM + LLM +# from .language_model.qllava_qllama import QLlavaLlamaConfig, QLlavaLlamaModel +# NOTE: Linear -> fp8, similar to transformer engine +# from .language_model.qllama import QLlamaConfig, QLlamaForCausalLM, QLlamaModel +# NOTE: Linear + Activation -> fp8, haochen's iclr version +# from .language_model.qmemllama import QMemLlamaConfig, QMemLlamaForCausalLM, QMemLlamaModel +""" +TODO: + linear(weights): + simulated fp8: done + real fp8: in-progress (code already implmented) + activation: + simulated fp8: done + real fp8: in-progress (still coding) + optimizers: + current VILA: bf16 + simulated fp8: done + real fp8 + fsdp (single node): done + real fp8 + fsdp (multiple node): in-progress +1. linear fp8 +2. activation fp8 +3. fp8 infernce example (load directly from a fp8 and fwd) +4. bind fp8 related configs to QLlamaConfig {"coat_fp8_args": {}} +""" +from .language_model.fp8linearqwen2 import FP8LinearQwen2Config, FP8LinearQwen2Model +from .language_model.qllava_qllama import QLlavaLlamaConfig, QLlavaLlamaModel diff --git a/llava/model/apply_delta.py b/llava/model/apply_delta.py new file mode 100644 index 0000000000000000000000000000000000000000..767d0612da9a1fd77058e3d004d16e45d572e3ca --- /dev/null +++ b/llava/model/apply_delta.py @@ -0,0 +1,77 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +# This file is modified from https://github.com/haotian-liu/LLaVA/ + +""" +Usage: +python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta +""" +import argparse + +import torch +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llava import LlavaLlamaForCausalLM + + +def apply_delta(base_model_path, target_model_path, delta_path): + print("Loading base model") + base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + + print("Loading delta") + delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) + + print("Applying delta") + for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): + if name not in base.state_dict(): + assert name in [ + "model.mm_projector.weight", + "model.mm_projector.bias", + ], f"{name} not in base model" + continue + if param.data.shape == base.state_dict()[name].shape: + param.data += base.state_dict()[name] + else: + assert name in [ + "model.embed_tokens.weight", + "lm_head.weight", + ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" + bparam = base.state_dict()[name] + param.data[: bparam.shape[0], : bparam.shape[1]] += bparam + + print("Saving target model") + delta.save_pretrained(target_model_path) + delta_tokenizer.save_pretrained(target_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + + args = parser.parse_args() + + apply_delta(args.base_model_path, args.target_model_path, args.delta_path) diff --git a/llava/model/builder.py b/llava/model/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..4c715ddf2eaeaa59d9807b6fdc3d31a71005067c --- /dev/null +++ b/llava/model/builder.py @@ -0,0 +1,161 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# This file is modified from https://github.com/haotian-liu/LLaVA/ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
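+
+"""Checkpoint-loading utilities for LLaVA-style multi-modal models.
+
+Hedged usage sketch (the checkpoint path is a placeholder, not a released name):
+
+    from llava.mm_utils import get_model_name_from_path
+    from llava.model.builder import load_pretrained_model
+
+    model_path = "/path/to/audio-flamingo-3-checkpoint"  # placeholder path
+    tokenizer, model, image_processor, context_len = load_pretrained_model(
+        model_path, get_model_name_from_path(model_path)
+    )
+
+load_pretrained_model also handles 4-bit / 8-bit loading (load_4bit / load_8bit)
+and LoRA checkpoints; unmerged LoRA weights additionally require model_base.
+"""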
+ + +import os +import warnings + +import torch +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PretrainedConfig + +from llava.model import LlavaLlamaModel +from llava.model.utils import is_mm_model + + +def load_pretrained_model( + model_path, + model_name, + model_base=None, + load_8bit=False, + load_4bit=False, + device_map="auto", + device="cuda", + **kwargs, +): + kwargs = {"device_map": device_map, **kwargs} + + if device != "cuda": + kwargs["device_map"] = {"": device} + + if load_8bit: + kwargs["load_in_8bit"] = True + elif load_4bit: + kwargs["load_in_4bit"] = True + kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + ) + else: + kwargs["torch_dtype"] = torch.float16 + # kwargs["torch_dtype"] = torch.bfloat16 + + if is_mm_model(model_path): + # Load LLaVA model + ## TODO @yunhao: mind fixing lora + if "lora" in model_name.lower() and model_base is None: + warnings.warn( + "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged." + ) + if ("lora" in model_name.lower() or "dora" in model_name.lower()) and model_base is not None: + lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) + print(lora_cfg_pretrained) + print("Loading LLaVA from base model...") + config = AutoConfig.from_pretrained(model_base) + prepare_config_for_eval(config, kwargs) + model = LlavaLlamaModel.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs) + tokenizer = model.tokenizer + token_num, tokem_dim = model.llm.lm_head.out_features, model.llm.lm_head.in_features + if model.llm.lm_head.weight.shape[0] != token_num: + model.llm.lm_head.weight = torch.nn.Parameter( + torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype) + ) + model.llm.embed_tokens.weight = torch.nn.Parameter( + torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype) + ) + + print("Loading additional LLaVA weights...") + if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")): + non_lora_trainables = torch.load( + os.path.join(model_path, "non_lora_trainables.bin"), + map_location="cpu", + ) + else: + # this is probably from HF Hub + from huggingface_hub import hf_hub_download + + def load_from_hf(repo_id, filename, subfolder=None): + cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder) + return torch.load(cache_file, map_location="cpu") + + non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin") + non_lora_trainables = { + (k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items() + } + if any(k.startswith("model.model.") for k in non_lora_trainables): + non_lora_trainables = { + (k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items() + } + model.load_state_dict(non_lora_trainables, strict=False) + + from peft import PeftModel + + print("Loading LoRA weights...") + model = PeftModel.from_pretrained(model, model_path) + print("Merging LoRA weights...") + model = model.merge_and_unload() + print("Model is loaded...") + else: + config = AutoConfig.from_pretrained(model_path) + config.resume_path = model_path + prepare_config_for_eval(config, kwargs) + model = LlavaLlamaModel(config=config, 
low_cpu_mem_usage=True, **kwargs) + tokenizer = model.tokenizer + else: + # Load language model + if model_base is not None: + # PEFT model + from peft import PeftModel + + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) + print(f"Loading LoRA weights from {model_path}") + model = PeftModel.from_pretrained(model, model_path) + print(f"Merging weights") + model = model.merge_and_unload() + print("Convert to FP16...") + model.to(torch.float16) + else: + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, legacy=False) + model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + model.eval() + image_processor = None + if is_mm_model(model_path): + model.resize_token_embeddings(len(tokenizer)) + + if hasattr(model.llm.config, "max_sequence_length"): + context_len = model.config.max_sequence_length + else: + context_len = 2048 + + return tokenizer, model, image_processor, context_len + + +def prepare_config_for_eval(config: PretrainedConfig, kwargs: dict): + try: + # compatible with deprecated config convention + if getattr(config, "vision_tower_cfg", None) is None: + config.vision_tower_cfg = config.mm_vision_tower + except AttributeError: + raise ValueError(f"Invalid configuration! Cannot find vision_tower in config:\n{config}") + + config.model_dtype = kwargs.pop("torch_dtype").__str__() diff --git a/llava/model/coat/activation/__init__.py b/llava/model/coat/activation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69923298db60fa8275bf22895fd5c54cf102e9b9 --- /dev/null +++ b/llava/model/coat/activation/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + diff --git a/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTorch.py b/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTorch.py new file mode 100644 index 0000000000000000000000000000000000000000..04acd2fb73f707152839e1a97acb67e6f8df873b --- /dev/null +++ b/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTorch.py @@ -0,0 +1,101 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import math + +import torch + + +def floatExMy_quantize_torch(x, e_bit, m_bit, stochastic): + sign, x_abs = x.sign(), x.abs() + Elow, Ehigh, Mhigh = -(2 ** (e_bit - 1)) + 2, 2 ** (e_bit - 1), 2**m_bit + expo = torch.floor(torch.log2(x_abs)) + expo = torch.clamp(expo, min=Elow, max=Ehigh) + mant = x_abs / torch.exp2(expo) + + mant_int = torch.floor(mant) + mant_frac = mant - mant_int + mant_frac = mant_frac * Mhigh + if stochastic: + noise = mant_frac.new(mant_frac.shape).uniform_(-0.5, 0.5) + mant_frac.add_(noise) + mant_frac = torch.round(mant_frac) + + mant_q = mant_int + mant_frac / Mhigh + y = sign * (2**expo) * mant_q + y = y.to(x) + + return y + + +def floatExM0_quantize_torch(x, e_bit, stochastic): + sign, x_abs = x.sign(), x.abs() + Elow, Ehigh = -(2 ** (e_bit - 1)) + 1, 2 ** (e_bit - 1) + expo = torch.log2(x_abs) + if stochastic: + noise = expo.new(expo.shape).uniform_(-0.5, 0.5) + expo.add(noise) + log_bias = math.log2(4 / 3) - 1 / 2 + expo.add(torch.ones_like(expo) * log_bias) + expo = torch.clamp(expo, min=Elow - 1, max=Ehigh) + expo = torch.round(expo) + + y = sign * (2**expo) * (expo > Elow) # When underflow, set the value to 0 + y = y.to(x) + + return y + + +def Dynamic_quantize_torch(x, bit, stochastic): + if stochastic: + raise NotImplementedError("Dynamic Tree quantization does not support stochastic") + sign, x_abs = x.sign(), x.abs() + expo = torch.ceil(torch.log10(x_abs)) + expo = torch.clamp(expo, min=2 - bit) + mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1 + + mant_frac = mant * 2 ** (bit - 2 - expo.abs()) + mant_frac = torch.round(mant_frac) + mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1 + y = sign * (10**expo) * mant_frac / 10 + + zero_mask = y.abs() > 1.01 * 10 ** (1 - bit) + y = y * zero_mask + y = y.to(x) + return y + + +def ZeroDynamic_quantize_torch(x, bit, stochastic): + if stochastic: + raise NotImplementedError("Dynamic Tree quantization does not support stochastic") + sign, x_abs = x.sign(), x.abs() + expo = torch.ceil(torch.log10(x_abs)) + expo = torch.clamp(expo, min=2 - bit) + mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1 + + mant_frac = mant * 2 ** (bit - 2 - expo.abs()) + mant_frac = torch.round(mant_frac) + mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1 + y = sign * (10**expo) * mant_frac / 10 + + y = y.to(x) + return y diff --git a/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTriton.py b/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTriton.py new file mode 100644 index 0000000000000000000000000000000000000000..db85e7fc42dc54539f54400ccff280402e126f0d --- /dev/null +++ b/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTriton.py @@ -0,0 +1,181 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import math +import struct + +import numpy as np +import torch +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + + +def floatExMy_quantize_triton(x, e_bit, m_bit, stochastic): + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + y = torch.zeros_like(x) + + if x.dtype in [torch.bfloat16, torch.float32]: + if stochastic: + noise = x.new(x.shape).uniform_(-0.5, 0.5) + _floatExMy_stochastic_quantize_kernel[grid](x, noise, y, n_elements, e_bit, m_bit) + else: + _floatExMy_quantize_kernel[grid](x, y, n_elements, e_bit, m_bit) + else: + raise NotImplementedError(f"Other data format {x.dtype} for float quantization triton") + + return y + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4), + triton.Config( + { + "BLOCK_SIZE": 1024, + }, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE": 2048, + }, + num_stages=1, + ), + ], + key=["n_elements"], +) +@triton.jit +def _floatExMy_quantize_kernel( + x_ptr, + output_ptr, + n_elements, + e_bit, + m_bit, + BLOCK_SIZE: tl.constexpr, +): + if isinstance(e_bit, tl.constexpr): + ebit = e_bit.value + else: + ebit = e_bit + + if isinstance(m_bit, tl.constexpr): + mbit = m_bit.value + else: + mbit = m_bit + + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + + x = x.to(tl.float32) + sign = 1 - 2 * libdevice.signbit(x) + x_abs = tl.abs(x) + Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2 + Ehigh = tl.exp2((ebit - 1).to(tl.float32)) + Mhigh = tl.exp2(mbit.to(tl.float32)) + expo = tl.floor(tl.log2(x_abs)) + expo = tl.clamp(expo, min=Elow, max=Ehigh) + mant = x_abs / tl.exp2(expo) + + mant_int = tl.floor(mant) + mant_frac = mant - mant_int + mant_frac = mant_frac * Mhigh + # mant_frac = mant_frac + noise + mant_frac = libdevice.round(mant_frac) + + mant_q = mant_int + mant_frac / Mhigh + y = sign * tl.exp2(expo) * mant_q + y = y.to(x_ptr.dtype.element_ty) + + tl.store(output_ptr + offsets, y, mask=mask) + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4), + triton.Config( + { + "BLOCK_SIZE": 1024, + }, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE": 2048, + }, + num_stages=1, + ), + ], + key=["n_elements"], +) +@triton.jit +def _floatExMy_stochastic_quantize_kernel( + x_ptr, + noise_ptr, + output_ptr, + n_elements, + e_bit, + m_bit, + BLOCK_SIZE: tl.constexpr, +): + if isinstance(e_bit, tl.constexpr): + ebit = e_bit.value + else: + ebit = e_bit + + if isinstance(m_bit, tl.constexpr): + mbit = m_bit.value + else: + mbit = m_bit + + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + noise = tl.load(noise_ptr + offsets, mask=mask) + + x = x.to(tl.float32) + sign = 1 - 2 * libdevice.signbit(x) + x_abs = tl.abs(x) + Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2 + Ehigh = tl.exp2((ebit - 1).to(tl.float32)) + Mhigh = tl.exp2(mbit.to(tl.float32)) + expo = tl.floor(tl.log2(x_abs)) + expo = tl.clamp(expo, min=Elow, max=Ehigh) + mant = x_abs / tl.exp2(expo) + + mant_int = tl.floor(mant) + mant_frac = mant - mant_int + mant_frac = mant_frac * Mhigh + mant_frac = mant_frac + noise + mant_frac = 
libdevice.round(mant_frac) + + mant_q = mant_int + mant_frac / Mhigh + y = sign * tl.exp2(expo) * mant_q + y = y.to(x_ptr.dtype.element_ty) + + tl.store(output_ptr + offsets, y, mask=mask) diff --git a/llava/model/coat/activation/fake_quantization/quantize_function.py b/llava/model/coat/activation/fake_quantization/quantize_function.py new file mode 100644 index 0000000000000000000000000000000000000000..6b18d45603ecd3f6564a8dbac3810d8ce701fbb6 --- /dev/null +++ b/llava/model/coat/activation/fake_quantization/quantize_function.py @@ -0,0 +1,239 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import re + +import torch + +from .FloatPointQuantizeTorch import * +from .FloatPointQuantizeTriton import * + + +def block_cut(input, row_block, column_block, pad_block=False): + # print(input.shape) + original_shape = input.shape + # input tensor shape is M * N + if len(input.shape) > 2: + input = input.reshape(-1, input.shape[2]) + elif len(input.shape) == 2: + pass + else: + raise ValueError(f"input shape {input.shape} does not match for block cut, {input}") + M, N = input.shape[0], input.shape[1] + + if row_block == -1: + row_block = M + if column_block == -1: + column_block = N + + if pad_block: + row_remainder, col_remainder = M % row_block, N % column_block + if row_remainder: + row_pad = row_block - row_remainder + else: + row_pad = 0 + if col_remainder: + col_pad = column_block - col_remainder + else: + col_pad = 0 + + input = torch.nn.functional.pad( + input, (0, col_pad, 0, row_pad), "constant", 0 + ) # refer to torch's doc to see why + M, N = input.shape[0], input.shape[1] + row_num, column_num = M // row_block, N // column_block + else: + row_num, column_num = M // row_block, N // column_block + + assert row_num * row_block == M, f"{row_num}, {row_block}, {M}, {original_shape}" + assert column_num * column_block == N, f"{column_num}, {column_block}, {N}, {original_shape}" + # print(input.shape) + input = ( + input.reshape(row_num, row_block, column_num, column_block) + .permute(0, 2, 1, 3) + .reshape(row_num * column_num, row_block, column_block) + ) + # print(input.shape) + return input + + +def block_reshape(input, origin_input, row_block, column_block, pad_block=False): + if len(origin_input.shape) > 2: + flatten_input = origin_input.reshape(-1, origin_input.shape[2]) + elif len(origin_input.shape) == 2: + flatten_input = origin_input + else: + raise ValueError(f"input shape {input.shape} does not match for block cut") + + M, N = flatten_input.shape[0], flatten_input.shape[1] + + if row_block == -1: + row_block = M + if column_block == -1: + column_block = N + + if pad_block: + row_remainder, col_remainder = M % row_block, N % column_block + if row_remainder: + row_pad = row_block - 
row_remainder + else: + row_pad = 0 + if col_remainder: + col_pad = column_block - col_remainder + else: + col_pad = 0 + + pad_origin_input = torch.nn.functional.pad(origin_input, (0, col_pad, 0, row_pad), "constant", 0) + M, N = pad_origin_input.shape[0], pad_origin_input.shape[1] + row_num, column_num = M // row_block, N // column_block + else: + row_num, column_num = M // row_block, N // column_block + + input = ( + input.reshape(row_num, column_num, row_block, column_block) + .permute(0, 2, 1, 3) + .reshape(row_num * row_block, column_num * column_block) + ) + + M, N = flatten_input.shape[0], flatten_input.shape[1] + input = input[:M, :N] + + if len(origin_input.shape) > 2: + input = input.reshape(origin_input.shape) + elif len(origin_input.shape) == 2: + pass + else: + raise ValueError(f"input shape {input.shape} does not match for block reshape") + + return input + + +def block_verify_int8(input, row_block, column_block, layer_type, necessary=True): + Binput = block_cut(input, row_block, column_block) + Binput = Binput.to(torch.float32) + + for n in range(Binput.shape[0]): + unique_values = len(torch.unique(Binput[n, :, :])) + if unique_values > 256: + if necessary: + raise ValueError(f"{layer_type} contains more than 256 unique values.") + else: + return False + return True + + +def block_quant(input, symm, bits, stochastic, epsilon, apply_quantize, layer_name): + Quant_fn = SymmQuantizer + return Quant_fn.apply(input, symm, bits, stochastic, epsilon, apply_quantize, layer_name) + + +def extract_bit(string): + match = re.match(r"INT(\d+)", string) # INT8 + if match: + return "integer", int(match.group(1)), None + match = re.match(r"E(\d+)M(\d+)", string) # E4M3 / E5M2 + if match: + Ebit, Mbit = int(match.group(1)), int(match.group(2)) + if Ebit == 1: + return "integer", Mbit + 1, None + if Mbit == 0: + return "floatExM0", int(match.group(1)), 0 + return "floatExMy", int(match.group(1)), int(match.group(2)) + match = re.match(r"DE(\d+)", string) + if match: + return "Dynamic", int(match.group(1)), None + match = re.match(r"ZeroD(\d+)", string) + if match: + return "ZeroDynamic", int(match.group(1)), None + raise ValueError(f"{string} data format is not supported") + + +class SymmQuantizer(torch.autograd.function.InplaceFunction): + @staticmethod + def forward(ctx, input, symm, bits, stochastic, epsilon, apply_quantize=True, layer_name=None): + with torch.no_grad(): + absmax_per_block = input.abs().amax(dim=(1, 2)).unsqueeze(1).unsqueeze(2) + epsilon + + if bits == "100" or not apply_quantize: + return input, input, torch.ones_like(absmax_per_block) + elif bits == "FP32": + return input.to(torch.float32), input.to(torch.float32), torch.ones_like(absmax_per_block) + elif bits == "FP16": + return input.to(torch.float16), input.to(torch.float16), torch.ones_like(absmax_per_block) + elif bits == "BF16": + return input.to(torch.bfloat16), input.to(torch.bfloat16), torch.ones_like(absmax_per_block) + else: + QuantType, bit1, bit2 = extract_bit(bits) + if not symm: + bit1 = bit1 + 1 # pretend to be asymmtric + + if QuantType == "integer": + Qn, Qp = -(2 ** (bit1 - 1) - 1), 2 ** (bit1 - 1) - 1 + elif QuantType == "floatExMy": + Qn, Qp = -(2 - 2 ** (-bit2)) * (2 ** (2 ** (bit1 - 1))), (2 - 2 ** (-bit2)) * ( + 2 ** (2 ** (bit1 - 1)) + ) + if bit1 == 4 and bit2 == 3: # E4M3 + Qn, Qp = -448, 448 + if bit1 == 5 and bit2 == 2: # E5M2 + Qn, Qp = -57344, 57344 + elif QuantType == "floatExM0": + Qn, Qp = -(2 ** (2 ** (bit1 - 1))) + 1, 2 ** (2 ** (bit1 - 1)) + elif QuantType == "Dynamic": + Qn, Qp = 
-1, 1 + elif QuantType == "ZeroDynamic": + Qn, Qp = -1, 1 + else: + raise NotImplementedError(f"{bits} is not supported by quantization") + scale_per_block = (2 * absmax_per_block) / (Qp - Qn) + scale_per_block = scale_per_block.to(input) + + Qinput = input / scale_per_block + + if QuantType == "integer": + if stochastic: + noise = Qinput.new(Qinput.shape).uniform_(-0.5, 0.5) + Qinput.add_(noise) + Qinput.clamp_(Qn, Qp).round_() + elif QuantType == "floatExMy": + # Qinput = floatExMy_quantize_torch(Qinput, bit1, bit2, stochastic) + Qinput = floatExMy_quantize_triton(Qinput, bit1, bit2, stochastic) + elif QuantType == "floatExM0": + Qinput = floatExM0_quantize_torch(Qinput, bit1, stochastic) + else: + raise NotImplementedError(f"{bits} is not supported by quantization") + + RQinput = Qinput * scale_per_block + + if input.dtype != Qinput.dtype: + print( + f"Input type is {input.dtype}, Qinput type is {Qinput.dtype}, scale_per_block type is {scale_per_block.dtype}", + file=open("debug.txt", "a"), + ) + import IPython + + IPython.embed() + return RQinput, Qinput, scale_per_block + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None, None, None, None, None diff --git a/llava/model/coat/activation/fake_quantization/utils.py b/llava/model/coat/activation/fake_quantization/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bc4ec2dae1ef7be1dc7c82c79439b7091d49feb6 --- /dev/null +++ b/llava/model/coat/activation/fake_quantization/utils.py @@ -0,0 +1,115 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import os + +import matplotlib.pyplot as plt +import numpy as np +import torch + + +def list_has_common_element(list1, list2): + set1 = set(list1) + set2 = set(list2) + return len(set1.intersection(set2)) > 0 + + +def calculate_scale_num(input, row_block, col_block): + if len(input.shape) > 2: + input = input.reshape(-1, input.shape[2]) + elif len(input.shape) == 2: + pass + else: + raise ValueError(f"input shape {input.shape} does not match for block cut, {input}") + M, N = input.shape[0], input.shape[1] + + if row_block == -1: + row_block = M + if col_block == -1: + col_block = N + + return input.numel() / (row_block * col_block) + + +def quant_get_local_rank() -> int: + return int(os.environ.get("LOCAL_RANK") or 0) + + +def format_string_with_condition( + input_string, + condition_config, + symm, + bits, + blocksize_config, + input_pad=20, +): + padded_string = input_string.ljust(input_pad) + output_string = padded_string + + for k, v in condition_config.items(): + if v: + output_string = output_string + k.ljust(10) + "True".ljust(6) + "".ljust(6) + else: + output_string = output_string + k.ljust(10) + "".ljust(6) + "False".ljust(6) + + output_string = output_string + f"Symm {symm}".ljust(10) + + for k, v in bits.items(): + output_string = output_string + f"{k} bit".ljust(10) + v.ljust(10) + for k, v in blocksize_config.items(): + output_string += f"{k}: {v}".ljust(15) + + return output_string + + +def print_warning(sentence): + print("*" * (len(sentence) + 4)) + print(f"* {sentence} *") + print("*" * (len(sentence) + 4)) + + +def check_nan_inf(tensor, check_nan, check_inf): + if check_nan: + contain_nan = torch.isnan(tensor).any() + else: + contain_nan = False + if check_inf: + contain_inf = torch.isinf(tensor).any() + else: + contain_inf = False + return contain_nan, contain_inf + + +def move_torch_to_numpy(tensor): + if tensor is None: + return None + + if tensor.is_cuda: + tensor = tensor.cpu() + return tensor.detach().float().numpy() + + +def flatten_to_1d(tensor): + if tensor is None: + return None + + return tensor.reshape(-1) diff --git a/llava/model/coat/activation/models/_fp8_quantization_config.py b/llava/model/coat/activation/models/_fp8_quantization_config.py new file mode 100644 index 0000000000000000000000000000000000000000..579df908203575abf8da9565995640c4ae459e3b --- /dev/null +++ b/llava/model/coat/activation/models/_fp8_quantization_config.py @@ -0,0 +1,67 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
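+
+"""Configuration dataclass for the COAT FP8 quantization settings used by the
+activation and optimizer code in this package.
+
+The defaults pair E4M3 (fabit / fwbit / fobit) with E5M2 (babit / bwbit / bobit),
+the usual forward/backward FP8 split. A hedged construction sketch with
+illustrative, non-default values:
+
+    qargs = QuantizationConfig(
+        quantize_model="true",  # string flag, "false" by default
+        group_size=16,          # illustrative per-group scaling size; -1 is the default
+    )
+"""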
+ +from dataclasses import dataclass + +from transformers import PretrainedConfig + + +@dataclass +class QuantizationConfig: + quantize_model: str = "false" + symm: bool = True + epsilon: float = 1e-10 + fabit: str = "E4M3" + fwbit: str = "E4M3" + fobit: str = "E4M3" + babit: str = "E5M2" + bwbit: str = "E5M2" + bobit: str = "E5M2" + qchoice: str = "none" + group_size: int = -1 + pad_to_multiple_of: int = 0 + weight_memory_efficient: bool = True + + # Legacy + row_blocksize: int = -1 + col_blocksize: int = -1 + + def __init__( + self, + quantize_model: str = "false", + symm: bool = True, + epsilon: float = 1e-10, + fabit: str = "E4M3", + fwbit: str = "E4M3", + fobit: str = "E4M3", + babit: str = "E5M2", + bwbit: str = "E5M2", + bobit: str = "E5M2", + qchoice: str = "none", + group_size: int = -1, + pad_to_multiple_of: int = 0, + weight_memory_efficient: bool = True, + row_blocksize: int = -1, + col_blocksize: int = -1, + **kwargs, + ): + super().__init__() + self.quantize_model = quantize_model + self.symm = symm + self.epsilon = epsilon + self.fabit = fabit + self.fwbit = fwbit + self.fobit = fobit + self.babit = babit + self.bwbit = bwbit + self.bobit = bobit + self.qchoice = qchoice + self.group_size = group_size + self.pad_to_multiple_of = pad_to_multiple_of + self.weight_memory_efficient = weight_memory_efficient + + self.row_blocksize = row_blocksize + self.col_blocksize = col_blocksize diff --git a/llava/model/coat/activation/models/_fp8_weightcache.py b/llava/model/coat/activation/models/_fp8_weightcache.py new file mode 100644 index 0000000000000000000000000000000000000000..85b1c504c74b83299181020e1275133428f93e1d --- /dev/null +++ b/llava/model/coat/activation/models/_fp8_weightcache.py @@ -0,0 +1,48 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
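+"""
+Per-module cache for FP8-quantized weights.
+
+`FP8CacheWeightModule.prepare_weight` quantizes a weight with
+`fp8_division_transpose` on the first microbatch and caches the result as module
+attributes; later microbatches reuse the cache instead of re-quantizing. With
+`weight_memory_efficient` enabled, only the per-group scale is kept and the FP8
+weight is re-derived when needed. Subclasses are expected to define
+`self.fwobits`, which `prepare_weight` reads for the weight format.
+
+Illustrative first-microbatch call (names are examples only; later calls return
+only the cached scale in the memory-efficient mode):
+
+    w_fp8, w_fp8_t, w_s = module.prepare_weight(module.q_proj.weight, "q_proj", True)
+"""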
+ +import torch.nn as nn + +from ..real_quantization import fp8_division_transpose + + +class FP8CacheWeightModule(nn.Module): + def __init__(self, config, qargs, layer_id): + super().__init__() + self.config = config + self.qargs = qargs + self.layer_id = layer_id + + def prepare_weight(self, weight, weight_name, is_first_microbatch): + if is_first_microbatch: + if self.qargs.weight_memory_efficient: + # print(f"{weight_name} uses first microbatch") + weight_fp8, weight_s, weight_fp8_t = fp8_division_transpose( + weight, self.qargs.group_size, self.fwobits["fwbit"] + ) + setattr(self, f"{weight_name}_fp8_scale", weight_s) + return weight_fp8, weight_fp8_t, weight_s + else: + # print(f"{weight_name} uses first microbatch") + weight_fp8, weight_s, weight_fp8_t = fp8_division_transpose( + weight, self.qargs.group_size, self.fwobits["fwbit"] + ) + setattr(self, f"{weight_name}_fp8", weight_fp8) + setattr(self, f"{weight_name}_fp8_t", weight_fp8_t) + setattr(self, f"{weight_name}_fp8_scale", weight_s) + return weight_fp8, weight_fp8_t, weight_s + else: + if self.qargs.weight_memory_efficient: + return getattr(self, f"{weight_name}_fp8_scale") + else: + return ( + getattr(self, f"{weight_name}_fp8"), + getattr(self, f"{weight_name}_fp8_t"), + getattr(self, f"{weight_name}_fp8_scale"), + ) + + def forward(self, x): + pass diff --git a/llava/model/coat/activation/models/_fp8manager.py b/llava/model/coat/activation/models/_fp8manager.py new file mode 100644 index 0000000000000000000000000000000000000000..5270acf79adc7d17838959b0c679cba3b43c5bee --- /dev/null +++ b/llava/model/coat/activation/models/_fp8manager.py @@ -0,0 +1,31 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +class FP8Manager: + """Class to keep track of and manipulate the global + FP8 state at different stages of execution. + """ + + is_first_microbatch = False diff --git a/llava/model/coat/activation/models/coat_llama.py b/llava/model/coat/activation/models/coat_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..9b6763c027ac05d04c1b66b483cf886c0e1f7fab --- /dev/null +++ b/llava/model/coat/activation/models/coat_llama.py @@ -0,0 +1,1479 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. 
It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import os +from fnmatch import fnmatch +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaForCausalLM, + LlamaLinearScalingRotaryEmbedding, + LlamaModel, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + _prepare_4d_causal_attention_mask_with_cache_position, + apply_rotary_pos_emb, + repeat_kv, + rotate_half, +) +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) + +from ..real_quantization import ( + Coat_quantize_bgn, + Coat_quantize_end, + fp8_add_Ifp_Ifp_Ofp_Og16, + fp8_add_Ifp_Ifp_Ofp_Opt, + fp8_division, + fp8_division_transpose, + fp8_gelu_backward, + fp8_gelu_forward, + fp8_layernorm_noparam_backward, + fp8_layernorm_noparam_forward, + fp8_linear_backward, + fp8_linear_forward, + fp8_mul_backward, + fp8_mul_forward, + fp8_quantize, + fp8_quantize_pertensor, + fp8_quantize_pertensor_transpose, + fp8_rmsnorm_backward, + fp8_rmsnorm_forward, + fp8_silu_backward, + fp8_silu_forward, + fp8_transpose, +) + +# FP8 related +from ._fp8_quantization_config import QuantizationConfig +from ._fp8_weightcache import FP8CacheWeightModule +from ._fp8manager import FP8Manager + +logger = logging.get_logger(__name__) + + +class CoatLlamaConfig(LlamaConfig): + model_type = "fp8_llama" + + +class CoatLlamaBeforeAttentionResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 1 * Linear layers + """ + + def __init__(self, config: CoatLlamaConfig, qargs: QuantizationConfig, layer_idx: Optional[int] = None): + 
super().__init__(config, qargs, layer_idx) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + + def forward(self, re_x, x, s, rmsnorm_weight): + if self.training: + if self.qargs.weight_memory_efficient: + # Prepare + with torch.no_grad(): + weight1_s = self.prepare_weight(self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch) + weight2_s = self.prepare_weight(self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch) + weight3_s = self.prepare_weight(self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch) + return _CoatLlamaBeforeAttentionResidual.apply( + re_x, + x, + s, + self.q_proj.weight, + None, + None, + weight1_s, + self.k_proj.weight, + None, + None, + weight2_s, + self.v_proj.weight, + None, + None, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # Prepare + with torch.no_grad(): + weight1, weight1_t, weight1_s = self.prepare_weight( + self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch + ) + weight2, weight2_t, weight2_s = self.prepare_weight( + self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch + ) + weight3, weight3_t, weight3_s = self.prepare_weight( + self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch + ) + return _CoatLlamaBeforeAttentionResidual.apply( + re_x, + x, + s, + self.q_proj.weight, + weight1, + weight1_t, + weight1_s, + self.k_proj.weight, + weight2, + weight2_t, + weight2_s, + self.v_proj.weight, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + return re_x, self.att_proj(self.attn_norm(re_x)) + + +class _CoatLlamaBeforeAttentionResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, + re_x, + in_x, + in_s, + weight1_origin, + weight1, + weight1_t, + weight1_s, + weight2_origin, + weight2, + weight2_t, + weight2_s, + weight3_origin, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + group_size, + fwobits, + layer_id, + config, + qargs, + eps=1e-5, + ): + # for autograd + if fwobits["fabit"] == "E4M3": + # in_x = in_x.to(torch.float8_e4m3fn) + in_x = in_x.view(torch.float8_e4m3fn) + else: + raise ValueError("fabit should be E4M3") 
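+        # `view(torch.float8_e4m3fn)` reinterprets the bits of the already quantized
+        # activation for autograd bookkeeping; a value cast with `.to()` would
+        # re-quantize data that is already paired with the per-group scales in `in_s`.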
+ + # LayerNorm + ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward( + in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True + ) + + # Linear Layer QKV Projection + if qargs.weight_memory_efficient: + assert weight1 is None # memory efficient + weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s) + weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s) + weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s) + + fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size) # query states + fc2_x = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, False, group_size) # key states + fc3_x = fp8_linear_forward(ln_x, ln_s, weight3, weight3_s, False, group_size) # value states + + # ==================== save for backward ==================== + ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s) + if qargs.weight_memory_efficient: + assert weight1_t is None and weight2_t is None and weight3_t is None + ctx.weight = weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s + else: + ctx.weight = weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s + + ctx.group_size = group_size + ctx.ln_utils = ln_utils + ctx.utils = fwobits, layer_id, config, qargs + + return re_x, fc1_x, fc2_x, fc3_x + + @staticmethod + def backward(ctx, fp_grad, query_g, key_g, value_g): + in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors + weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s = ctx.weight + + group_size = ctx.group_size + rms_weight, rstd, num_warps = ctx.ln_utils + fwobits, layer_id, config, qargs = ctx.utils + + # ==================== Begin backward ==================== + # Quantize the RoPE and FlashAttention Output. grad_input and grad_weight requires different data layout. + query_g, query_gs, query_g_t = fp8_quantize_pertensor_transpose( + query_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + key_g, key_gs, key_g_t = fp8_quantize_pertensor_transpose( + key_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + value_g, value_gs, value_g_t = fp8_quantize_pertensor_transpose( + value_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + + # Linear Layer QKV Projection + if qargs.weight_memory_efficient: + weight1_t, weight1_s = fp8_division_transpose( + weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True + ) + weight2_t, weight2_s = fp8_division_transpose( + weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True + ) + weight3_t, weight3_s = fp8_division_transpose( + weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True + ) + + fc1_g1, att_q_wg = fp8_linear_backward( + ln_x_t, ln_s, query_g, query_gs, query_g_t, weight1_t, weight1_s, group_size + ) + fc1_g2, att_k_wg = fp8_linear_backward(ln_x_t, ln_s, key_g, key_gs, key_g_t, weight2_t, weight2_s, group_size) + fc1_g3, att_v_wg = fp8_linear_backward( + ln_x_t, ln_s, value_g, value_gs, value_g_t, weight3_t, weight3_s, group_size + ) + + fc1_g = fc1_g1 + fc1_g2 + fc1_g3 + + # LayerNorm + in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc1_g, rms_weight, rstd, group_size, num_warps) + + # Add the gradient together, and prepare the input of the next layer. 
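+        # The fused add below sums the residual gradient (`fp_grad`) with the RMSNorm
+        # input gradient and emits the result twice: `re_g` in full precision for the
+        # residual stream, and `in_g` with scales `in_sg` / `in_sg_g16` for the
+        # quantized stream of the preceding block.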
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt( + fp_grad, in_g, group_size, fwobits["babit"], stochastic=False + ) + + # for autograd. forward's data type should be the same of backward tensor. this will not change the actual binary representation. + in_g = in_g.view(torch.float8_e4m3fn) + + # Although the next operator is a linear layer in MLPResidual module, we return in_sg_g16 to make the size compatible with the forward. Otherwise it will not pass autograd. + return ( + re_g, + in_g, + in_sg_g16, + att_q_wg, + None, + None, + None, + att_k_wg, + None, + None, + None, + att_v_wg, + None, + None, + None, + rms_weight_grad, + None, + None, + None, + None, + None, + None, + ) + + +class CoatLlamaAfterAttentionResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) 1 * Linear layers + """ + + def __init__(self, config: CoatLlamaConfig, qargs: QuantizationConfig, layer_id): + super().__init__(config, qargs, layer_id) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + def forward(self, re_x, in_x): + if self.training: + if self.qargs.weight_memory_efficient: + # prepare for the weight + with torch.no_grad(): + weight4_s = self.prepare_weight(self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch) + + return _CoatLlamaAfterAttentionResidual.apply( + re_x, + in_x, + self.o_proj.weight, + None, + None, + weight4_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # prepare for the weight + with torch.no_grad(): + weight4, weight4_t, weight4_s = self.prepare_weight( + self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch + ) + + return _CoatLlamaAfterAttentionResidual.apply( + re_x, + in_x, + self.o_proj.weight, + weight4, + weight4_t, + weight4_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + return re_x + self.attn_out(in_x), None, None + + +class _CoatLlamaAfterAttentionResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, re_x, flash_x, weight4_origin, weight4, weight4_t, weight4_s, group_size, fwobits, layer_id, config, qargs + ): + # Quantize the FlashAttention Output + flash_qx, flash_s, _ = fp8_quantize_pertensor( + flash_x, group_size, fwobits["fabit"] + ) # Modified to make it memory efficient + + # # Attention Projection Linear Layer + if qargs.weight_memory_efficient: + assert weight4 is None # memory efficient + weight4, weight4_s = fp8_division(weight4_origin, qargs.group_size, fwobits["fwbit"], weight4_s) + fc4_x = fp8_linear_forward(flash_qx, flash_s, weight4, weight4_s, False, group_size) # + + # import IPython + # IPython.embed() + # Add the activations together + fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc4_x, flash_qx.dtype, group_size) + + # ==================== save for backward ==================== + ctx.save_for_backward(flash_x, flash_s) + if qargs.weight_memory_efficient: + assert weight4_t is None + ctx.weight = weight4_origin, weight4_s + else: + ctx.weight = weight4_t, weight4_s + ctx.group_size = 
group_size + ctx.fwobits = fwobits + ctx.utils = fwobits, layer_id, config, qargs + + # For autograd + out_x = out_x.view(torch.float8_e4m3fn) + + return fp_x, out_x, out_s + + @staticmethod + def backward(ctx, fp_grad, out_g, out_gs): + flash_x, flash_s = ctx.saved_tensors + weight4_t, weight4_s = ctx.weight + group_size = ctx.group_size + fwobits = ctx.fwobits + fwobits, layer_id, config, qargs = ctx.utils + + # for autograd + if fwobits["babit"] == "E5M2": + # out_g = out_g.to(torch.float8_e5m2) + out_g = out_g.view(torch.float8_e5m2) + else: + raise ValueError("babit should be E5M2") + out_gs_max = out_gs.max() + + # ==================== Begin backward ==================== + # Output Projection + out_g_t = fp8_transpose(out_g, transpose_output_2d=True) + + # We do not save an extra flash_x to save the memory usage + flash_x_t, flash_s = fp8_division_transpose( + flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True + ) + + if qargs.weight_memory_efficient: + weight4_t, weight4_s = fp8_division_transpose( + weight4_t, qargs.group_size, fwobits["fwbit"], weight4_s, only_transposed=True + ) + fc4_g, attn_out_wg = fp8_linear_backward( + flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight4_t, weight4_s, group_size + ) + + return fp_grad, fc4_g, attn_out_wg, None, None, None, None, None, None, None, None + + +class CoatLlamaMLPResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 2 / 3 * Linear layers + (4) GELU / Silu Activation + """ + + def __init__(self, config: CoatLlamaConfig, qargs: QuantizationConfig, layer_id, hidden_size: int): + super().__init__(config, qargs, layer_id) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.training = True + + # below is only used when training = False + assert config.hidden_act == "silu", "We only support silu activation currently" + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, re_x, x, s, rmsnorm_weight): + if self.training: + if self.qargs.weight_memory_efficient: # prepare for the weight + with torch.no_grad(): + weight1_s = self.prepare_weight(self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch) + weight2_s = self.prepare_weight(self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch) + weight3_s = self.prepare_weight(self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch) + + return _CoatLlamaMLPResidual.apply( + re_x, + x, + s, + self.gate_proj.weight, + None, + None, + weight1_s, + self.up_proj.weight, + None, + None, + weight2_s, + self.down_proj.weight, + None, + None, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # prepare for the weight + with torch.no_grad(): + weight1, weight1_t, weight1_s = self.prepare_weight( + self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch + ) + weight2, weight2_t, weight2_s = 
self.prepare_weight( + self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch + ) + weight3, weight3_t, weight3_s = self.prepare_weight( + self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch + ) + + return _CoatLlamaMLPResidual.apply( + re_x, + x, + s, + self.gate_proj.weight, + weight1, + weight1_t, + weight1_s, + self.up_proj.weight, + weight2, + weight2_t, + weight2_s, + self.down_proj.weight, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + raise NotImplementedError("Need TODO") + og_x = re_x + re_x = self.ff_norm(re_x) + re_x = self.ff_proj(re_x) + re_x = self.act(re_x) + re_x = self.ff_out(re_x) + re_x = og_x + re_x + return re_x, None, None + + +class _CoatLlamaMLPResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, + re_x, + in_x, + in_s, + weight1_origin, + weight1, + weight1_t, + weight1_s, + weight2_origin, + weight2, + weight2_t, + weight2_s, + weight3_origin, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + group_size, + fwobits, + layer_id, + config, + qargs, + eps=1e-5, + ): + # For autograd + if fwobits["fabit"] == "E4M3": + # in_x = in_x.to(torch.float8_e4m3fn) + in_x = in_x.view(torch.float8_e4m3fn) + else: + raise ValueError("fabit should be E4M3") + + # LayerNorm + ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward( + in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True + ) + + # Linear Layer of Up Projection and Gate Projection. They are fused as one linear layer. + if qargs.weight_memory_efficient: + assert weight1 is None and weight2 is None and weight3 is None # memory efficient + weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s) + weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s) + weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s) + + gate_x, gate_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size) # Gate Proj + up_x, up_s = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, True, group_size) # Up Proj + + # silu Activation + silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size) + + # Element-wise Multiplication + mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True) + + # Output Projection + if weight3 is None: # memory efficient + weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s) + fc3_x = fp8_linear_forward(mul_x, mul_s, weight3, weight3_s, False, group_size) + + # Add the activation together + fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc3_x, mul_x.dtype, group_size) + + # ==================== save for backward ==================== + ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s) + + ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s) + if ( + qargs.weight_memory_efficient + ): # Weight_1/2_origin will not be saved twice, so it will be more memory efficient. + assert weight1_t is None and weight2_t is None and weight3_t is None + ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s) + else: # Weight1/2_t is different from the origin weight, so saving it will consumes additional memory footprint. 
+ ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) + + ctx.group_size = group_size + ctx.ln_utils = ln_utils + ctx.utils = fwobits, layer_id, config, qargs + + out_x = out_x.view(torch.float8_e4m3fn) + + return fp_x, out_x, out_s + + @staticmethod + def backward(ctx, fp_grad, out_g, out_gs): + fwobits, layer_id, config, qargs = ctx.utils + + in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s = ctx.saved_tensors + + (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) = ctx.weight + group_size = ctx.group_size + rms_weight, rstd, num_warps = ctx.ln_utils + fwobits, layer_id, config, qargs = ctx.utils + + # For autograd + if fwobits["babit"] == "E5M2": + # out_g = out_g.to(torch.float8_e5m2) + out_g = out_g.view(torch.float8_e5m2) + else: + raise ValueError("babit should be E5M2") + out_gs_max = out_gs.max() + + # ==================== Begin backward ==================== + # Output Projection + out_gs = out_gs.max() + out_g_t = fp8_transpose(out_g, transpose_output_2d=True) + + if qargs.weight_memory_efficient: + weight3_t, weight3_s = fp8_division_transpose( + weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True + ) + fc3_g, weight3_grad = fp8_linear_backward( + mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight3_t, weight3_s, group_size + ) + + # [MEM TEST] + del out_g, out_g_t, weight3_t + + # Element-wise Multiplication, 1 means gate, 2 means up + mul_g1, (mul_g2, mul_gs2, mul_g2_t) = fp8_mul_backward( + silu_x, silu_s, up_x, up_s, fc3_g, group_size, fwobits["babit"], output_quantized_transpose=True + ) + + # Silu activation + silu_g, silu_gs, silu_g_t = fp8_silu_backward( + gate_x, gate_s, mul_g1, group_size, fwobits["babit"], output_quantized_transpose=True + ) + + # Linear Layer of Up and Gate Projection + if qargs.weight_memory_efficient: + weight1_t, weight1_s = fp8_division_transpose( + weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True + ) + weight2_t, weight2_s = fp8_division_transpose( + weight2_t, group_size, fwobits["fwbit"], weight2_s, only_transposed=True + ) + + # Gate Proj + fc1_g, weight1_grad = fp8_linear_backward( + ln_x_t, ln_s, silu_g, silu_gs, silu_g_t, weight1_t, weight1_s, group_size + ) + fc2_g, weight2_grad = fp8_linear_backward( + ln_x_t, ln_s, mul_g2, mul_gs2, mul_g2_t, weight2_t, weight2_s, group_size + ) + + fc_g = fc1_g + fc2_g + + # layerNorm + in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc_g, rms_weight, rstd, group_size, num_warps) + + # Add the gradient together + re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt( + fp_grad, in_g, group_size, fwobits["babit"], stochastic=False + ) + + in_g = in_g.view(torch.float8_e4m3fn) + + return ( + re_g, + in_g, + in_sg_g16, + weight1_grad, + None, + None, + None, + weight2_grad, + None, + None, + None, + weight3_grad, + None, + None, + None, + rms_weight_grad, + None, + None, + None, + None, + None, + None, + ) + + +class LlamaAttentionWithoutLinear(nn.Module): + """ + Remove the Q/K/V/O projection layer in LlamaAttention module and only calculate the attention logic. + The Q/K/V Projection is moved to BeforeAttention Module, and the O Projection is moved to AfterAttention Module. 
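+    The query/key/value states passed in are therefore already projected, with shape
+    (batch, seq_len, num_heads * head_dim) for queries and
+    (batch, seq_len, num_key_value_heads * head_dim) for keys/values; this class only
+    reshapes them per head, applies RoPE, and computes the attention output.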
+ """ + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2WithoutLinear(LlamaAttentionWithoutLinear): + """ + Remove the Q/K/V/O projection layer in LlamaFlashAttention2 module and only calculate the attention logic. + The Q/K/V Projection is moved to BeforeAttention Module, and the O Projection is moved to AfterAttention Module. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = query_states.size() + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaSdpaAttentionWithoutLinear(LlamaAttentionWithoutLinear): + """ + Remove the Q/K/V/O projection layer in LlamaSdpaAttention module and only calculate the attention logic. + The Q/K/V Projection is moved to BeforeAttention Module, and the O Projection is moved to AfterAttention Module. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + return attn_output, None, past_key_value + + +COAT_LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttentionWithoutLinear, + "flash_attention_2": LlamaFlashAttention2WithoutLinear, + "sdpa": LlamaSdpaAttentionWithoutLinear, +} + + +class CoatLlamaDecoderLayer(nn.Module): + def __init__(self, config: CoatLlamaConfig, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + + self.self_attn = COAT_LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.qargs = QuantizationConfig(**config.coat_fp8_args) + self.BeforeAttention = CoatLlamaBeforeAttentionResidual(config, self.qargs, layer_idx) + self.AfterAttention = CoatLlamaAfterAttentionResidual(config, self.qargs, layer_idx) + self.MLPResidual = CoatLlamaMLPResidual(config, self.qargs, layer_idx, self.hidden_size) + + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + quant_hidden_states: torch.Tensor, + scale_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): BF16 input to the layer of shape `(batch, seq_len, embed_dim)` + quant_hidden_states (`torch.float8_e4m3fn`): FP8 input to the layer of shape `(batch, seq_len, embed_dim)` + scale_hidden_states (`torch.bfloat16`): BF16 scaling factor to the layer of shape `(batch, seq_len, embed_dim // group_size)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + # Coat: The residual, LayerNorm, and the Q/K/V Projection Linear Layer + residual, query_states, key_states, value_states = self.BeforeAttention( + hidden_states, quant_hidden_states, scale_hidden_states, self.input_layernorm.weight + ) + + # Self Attention without any linear layer + hidden_states, self_attn_weights, present_key_value = self.self_attn( + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + # Coat: The Output Projection Linear Layer and Residual + hidden_states, quant_hidden_states, scale_hidden_states = self.AfterAttention(residual, hidden_states) + + # Residual Connection, LayerNorm, and the whole MLP module + hidden_states, quant_hidden_states, scale_hidden_states = self.MLPResidual( + hidden_states, quant_hidden_states, scale_hidden_states, self.post_attention_layernorm.weight + ) + + outputs = ((hidden_states, quant_hidden_states, scale_hidden_states),) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class CoatLlamaPreTrainedModel(PreTrainedModel): + config_class = CoatLlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class CoatLlamaModel(CoatLlamaPreTrainedModel): + """ + Coat Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`CoatLlamaDecoderLayer`] + + Args: + config: CoatLlamaConfig + """ + + def __init__(self, config: CoatLlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [CoatLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Quantize + self.qargs = QuantizationConfig(**config.coat_fp8_args) + self.quantize_input_before_block = Coat_quantize_bgn(self.qargs) + self.quantize_output_after_block = Coat_quantize_end(self.qargs) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + # Prepare the input for Coat decoderlayer + hidden_states, quant_hidden_states, scale_hidden_states = self.quantize_input_before_block(hidden_states) + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + quant_hidden_states, + scale_hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + quant_hidden_states, + scale_hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states, quant_hidden_states, scale_hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # Summarize the output of the Decoder Layer + hidden_states = self.quantize_output_after_block(hidden_states, quant_hidden_states, scale_hidden_states) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + _update_causal_mask = LlamaModel._update_causal_mask + + +class CoatLlamaForCausalLM(CoatLlamaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = CoatLlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + 
self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + forward = LlamaForCausalLM.forward + + prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation + + +# TODO +# class LlamaForSequenceClassification(LlamaPreTrainedModel): + +# class LlamaForQuestionAnswering(LlamaPreTrainedModel): + +# class LlamaForTokenClassification(LlamaPreTrainedModel): + + +def make_state_dict_compatible(state_dict: dict[str, torch.Tensor]): + compatible_state_dict = {} + + for key, value in state_dict.items(): + if fnmatch(key, "*self_attn.q_proj*"): + new_key = key.replace("self_attn.q_proj", "BeforeAttention.q_proj") + elif fnmatch(key, "*self_attn.k_proj*"): + new_key = key.replace("self_attn.k_proj", "BeforeAttention.k_proj") + elif fnmatch(key, "*self_attn.v_proj*"): + new_key = key.replace("self_attn.v_proj", "BeforeAttention.v_proj") + elif fnmatch(key, "*self_attn.o_proj*"): + new_key = key.replace("self_attn.o_proj", "AfterAttention.o_proj") + + elif fnmatch(key, "*mlp.gate_proj*"): + new_key = key.replace("mlp.gate_proj", "MLPResidual.gate_proj") + elif fnmatch(key, "*mlp.up_proj*"): + new_key = key.replace("mlp.up_proj", "MLPResidual.up_proj") + elif fnmatch(key, "*mlp.down_proj*"): + new_key = key.replace("mlp.down_proj", "MLPResidual.down_proj") + + else: + new_key = key + + compatible_state_dict[new_key] = value + + return compatible_state_dict + + +AutoConfig.register("fp8_llama", CoatLlamaConfig) +AutoModel.register(CoatLlamaConfig, CoatLlamaModel) +AutoModelForCausalLM.register(CoatLlamaConfig, CoatLlamaForCausalLM) diff --git a/llava/model/coat/activation/models/coat_llama_convert_from_hf.py b/llava/model/coat/activation/models/coat_llama_convert_from_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..c02da9e1b4211f8becaa2a54dc30a46fcf88185b --- /dev/null +++ b/llava/model/coat/activation/models/coat_llama_convert_from_hf.py @@ -0,0 +1,71 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +import argparse +import os +from dataclasses import asdict, dataclass, field +from typing import Optional + +import torch +import transformers +from coat.activation.models._fp8_quantization_config import QuantizationConfig +from coat.activation.models.coat_llama import CoatLlamaConfig, CoatLlamaForCausalLM, make_state_dict_compatible +from transformers import AutoConfig, AutoModelForCausalLM + + +@dataclass +class ConvertArguments: + model_name: str = field(metadata={"help": "The model name or path to download the LLaMA model"}) + save_path: str = field(metadata={"help": "The path where the converted model weights will be saved"}) + cache_dir: str = field(default=None, metadata={"help": "Directory to cache the model"}) + + +def download_and_convert_llama(convert_args: ConvertArguments, quantization_args: QuantizationConfig): + """ + Downloads a LLaMA model, converts its weights using `make_state_dict_compatible`, + and saves the converted model. + + Args: + model_name (str): The model name or path to download the LLaMA model. + save_path (str): The path where the converted model weights will be saved. + cache_dir (Optional[str]): Directory to cache the model. Defaults to None. 
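+
+    Note:
+        In this entry point the values above are provided through `convert_args`
+        (a `ConvertArguments` instance); the FP8 settings come from `quantization_args`.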
+ + Returns: + None + """ + model_name = convert_args.model_name + save_path = convert_args.save_path + cache_dir = convert_args.cache_dir + + # Step 1: Download the original LLaMA model + model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir) + + # Step 2: Initialize the model configuration for FP8 or other custom config + config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) + + # Step 3: Apply make_state_dict_compatible to convert weights + compatible_state_dict = make_state_dict_compatible(model.state_dict()) + + # Step 4: Create a new model instance with compatible configuration + fp8_config = CoatLlamaConfig(**config.to_dict()) + fp8_config.coat_fp8_args = asdict(quantization_args) + + converted_model = AutoModelForCausalLM.from_config(fp8_config) + converted_model.load_state_dict(compatible_state_dict) + + # Step 5: Save the converted model and configuration using save_pretrained + os.makedirs(save_path, exist_ok=True) + converted_model.save_pretrained(save_path) + print(f"Converted model saved at {save_path}") + + +if __name__ == "__main__": + # Parse command-line arguments + parser = transformers.HfArgumentParser((ConvertArguments, QuantizationConfig)) # NOTE: FP8 + convert_args, quantization_args = parser.parse_args_into_dataclasses() + + # Call the function with parsed arguments + download_and_convert_llama(convert_args, quantization_args) diff --git a/llava/model/coat/activation/models/coat_olmo.py b/llava/model/coat/activation/models/coat_olmo.py new file mode 100644 index 0000000000000000000000000000000000000000..d5d2cdedd73c56a68396aebdab1a2762876abc53 --- /dev/null +++ b/llava/model/coat/activation/models/coat_olmo.py @@ -0,0 +1,1942 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +""" +Adapted from +[MosaiclML](https://github.com/mosaicml/examples.git) and +[minGPT](https://github.com/karpathy/minGPT.git) +""" + +from __future__ import annotations + +import logging +import math +import sys +from abc import abstractmethod +from collections import defaultdict +from functools import partial +from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Sequence, Set, Tuple, cast + +import torch +import torch.backends.cuda +import torch.nn as nn +import torch.nn.functional as F +from olmo.aliases import PathOrStr +from olmo.beam_search import BeamSearch, Constraint, FinalSequenceScorer, Sampler +from olmo.config import ( + ActivationCheckpointingStrategy, + ActivationType, + BlockType, + CheckpointType, + FSDPWrapStrategy, + InitFnType, + LayerNormType, + ModelConfig, + QuantActivationConfig, + ShardedCheckpointerType, + TrainConfig, +) +from olmo.exceptions import OLMoConfigurationError +from olmo.initialization import init_normal +from olmo.model import ( + Activation, + BufferCache, + Dropout, + LayerNorm, + LayerNormBase, + OLMo, + OLMoBlock, + OLMoBlockGroup, + OLMoGenerateOutput, + OLMoOutput, + RMSLayerNorm, + RotaryEmbedding, + _non_meta_init_device, + activation_checkpoint_function, + alibi_attention_bias, + causal_attention_bias, + get_causal_attention_bias, + should_checkpoint_block, +) +from olmo.torch_util import ensure_finite_, get_cumulative_document_lengths +from torch import einsum + +from ..real_quantization import ( + Coat_quantize_bgn, + Coat_quantize_end, + fp8_add_Ifp_Ifp_Ofp_Og16, + fp8_add_Ifp_Ifp_Ofp_Opt, + fp8_division, + fp8_division_transpose, + fp8_gelu_backward, + fp8_gelu_forward, + fp8_layernorm_noparam_backward, + fp8_layernorm_noparam_forward, + fp8_linear_backward, + fp8_linear_forward, + fp8_mul_backward, + fp8_mul_forward, + fp8_quantize, + fp8_quantize_pertensor, + fp8_quantize_pertensor_transpose, + fp8_rmsnorm_backward, + fp8_rmsnorm_forward, + fp8_silu_backward, + fp8_silu_forward, + fp8_transpose, +) +from ._fp8_weightcache import FP8CacheWeightModule +from ._fp8manager import FP8Manager + +if sys.version_info.minor > 8: + from collections.abc import MutableMapping +elif sys.version_info.minor == 8: + from typing import MutableMapping +else: + raise SystemExit("This script supports Python 3.8 or higher") + +__all__ = [ + "LayerNormBase", + "LayerNorm", + "RMSLayerNorm", + "RotaryEmbedding", + "Activation", + "GELU", + "ReLU", + "SwiGLU", + "OLMoBlock", + "OLMoSequentialBlock", + "OLMo", + "OLMoOutput", + "OLMoGenerateOutput", +] + + +log = logging.getLogger(__name__) + + +class CoatOLMoBeforeAttentionResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 1 * Linear layers + """ + + def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, layer_id, fused_dims: tuple): + super().__init__(config, qargs, layer_id) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + self.ln_normalized_shape = config.d_model + self.att_proj = nn.Linear(config.d_model, sum(fused_dims), bias=config.include_bias, device=config.init_device) + + self.attn_norm = LayerNorm.build(config) + + def forward(self, re_x, x, s): + if self.training: + if self.qargs.weight_memory_efficient: + # Prepare + with torch.no_grad(): + weight1_s = 
self.prepare_weight(self.att_proj.weight, "att_proj", FP8Manager.is_first_microbatch) + return _CoatOLMoBeforeAttentionResidual.apply( + re_x, + x, + s, + self.att_proj.weight, + None, + None, + weight1_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # Prepare + with torch.no_grad(): + weight1, weight1_t, weight1_s = self.prepare_weight( + self.att_proj.weight, "att_proj", FP8Manager.is_first_microbatch + ) + return _CoatOLMoBeforeAttentionResidual.apply( + re_x, + x, + s, + self.att_proj.weight, + weight1, + weight1_t, + weight1_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + return re_x, self.att_proj(self.attn_norm(re_x)) + + +class _CoatOLMoBeforeAttentionResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, + re_x, + in_x, + in_s, + weight1_origin, + weight1, + weight1_t, + weight1_s, + group_size, + fwobits, + layer_id, + config, + qargs, + eps=1e-5, + ): + # for autograd + if fwobits["fabit"] == "E4M3": + # in_x = in_x.to(torch.float8_e4m3fn) + in_x = in_x.view(torch.float8_e4m3fn) + else: + raise ValueError("fabit should be E4M3") + + # LayerNorm + ln_x, ln_s, ln_x_t, ln_utils = fp8_layernorm_noparam_forward( + in_x, in_s, group_size, eps, transpose_output_2d=True + ) + + # Linear Layer QKV Projection + if qargs.weight_memory_efficient: + assert weight1 is None # memory efficient + weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s) + fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size) + + # ==================== save for backward ==================== + ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s) + if qargs.weight_memory_efficient: + assert weight1_t is None + ctx.weight = weight1_origin, weight1_s + else: + ctx.weight = weight1_t, weight1_s + ctx.group_size = group_size + ctx.ln_utils = ln_utils + ctx.utils = fwobits, layer_id, config, qargs + + return re_x, fc1_x + + @staticmethod + def backward(ctx, fp_grad, flash_g): + in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors + weight1_t, weight1_s = ctx.weight + group_size = ctx.group_size + mean, rstd, num_warps = ctx.ln_utils + fwobits, layer_id, config, qargs = ctx.utils + + # ==================== Begin backward ==================== + # Quantize the RoPE and FlashAttention Output. grad_input and grad_weight requires different data layout. + flash_g, flash_gs, flash_g_t = fp8_quantize_pertensor_transpose( + flash_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + + # Linear Layer QKV Projection + if qargs.weight_memory_efficient: + weight1_t, weight1_s = fp8_division_transpose( + weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True + ) + fc1_g, att_proj_wg = fp8_linear_backward( + ln_x_t, ln_s, flash_g, flash_gs, flash_g_t, weight1_t, weight1_s, group_size + ) + + # LayerNorm + in_g = fp8_layernorm_noparam_backward(in_x, in_s, fc1_g, group_size, mean, rstd, num_warps) + + # Add the gradient together, and prepare the input of the next layer. + re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt( + fp_grad, in_g, group_size, fwobits["babit"], stochastic=False + ) + + # for autograd. forward's data type should be the same of backward tensor. this will not change the actual binary representation. + in_g = in_g.view(torch.float8_e4m3fn) + + # Although the next operator is a linear layer in MLPResidual module, we return in_sg_g16 to make the size compatible with the forward. 
Otherwise it will not pass autograd. + return re_g, in_g, in_sg_g16, att_proj_wg, None, None, None, None, None, None, None, None, None + + +class CoatOLMoAfterAttentionResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) 1 * Linear layers + """ + + def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, layer_id): + super().__init__(config, qargs, layer_id) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + self.attn_out = nn.Linear(config.d_model, config.d_model, bias=config.include_bias, device=config.init_device) + + def forward(self, re_x, in_x): + if self.training: + if self.qargs.weight_memory_efficient: + # prepare for the weight + with torch.no_grad(): + weight2_s = self.prepare_weight(self.attn_out.weight, "attn_out", FP8Manager.is_first_microbatch) + + return _CoatOLMoAfterAttentionResidual.apply( + re_x, + in_x, + self.attn_out.weight, + None, + None, + weight2_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # prepare for the weight + with torch.no_grad(): + weight2, weight2_t, weight2_s = self.prepare_weight( + self.attn_out.weight, "attn_out", FP8Manager.is_first_microbatch + ) + + return _CoatOLMoAfterAttentionResidual.apply( + re_x, + in_x, + self.attn_out.weight, + weight2, + weight2_t, + weight2_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + return re_x + self.attn_out(in_x), None, None + + +class _CoatOLMoAfterAttentionResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, re_x, flash_x, weight2_origin, weight2, weight2_t, weight2_s, group_size, fwobits, layer_id, config, qargs + ): + # Quantize the FlashAttention Output + flash_qx, flash_s, _ = fp8_quantize_pertensor( + flash_x, group_size, fwobits["fabit"] + ) # Modified to make it memory efficient + + # # Attention Projection Linear Layer + if qargs.weight_memory_efficient: + assert weight2 is None # memory efficient + weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s) + fc2_x = fp8_linear_forward(flash_qx, flash_s, weight2, weight2_s, False, group_size) # + + # import IPython + # IPython.embed() + # Add the activations together + fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc2_x, flash_qx.dtype, group_size) + + # ==================== save for backward ==================== + ctx.save_for_backward(flash_x, flash_s) + if qargs.weight_memory_efficient: + assert weight2_t is None + ctx.weight = weight2_origin, weight2_s + else: + ctx.weight = weight2_t, weight2_s + ctx.group_size = group_size + ctx.fwobits = fwobits + ctx.utils = fwobits, layer_id, config, qargs + + # For autograd + out_x = out_x.view(torch.float8_e4m3fn) + + return fp_x, out_x, out_s + + @staticmethod + def backward(ctx, fp_grad, out_g, out_gs): + flash_x, flash_s = ctx.saved_tensors + weight2_t, weight2_s = ctx.weight + group_size = ctx.group_size + fwobits = ctx.fwobits + fwobits, layer_id, config, qargs = ctx.utils + + # for autograd + if fwobits["babit"] == "E5M2": + # out_g = out_g.to(torch.float8_e5m2) + out_g = out_g.view(torch.float8_e5m2) + else: + raise ValueError("babit should be E5M2") + out_gs_max = out_gs.max() + + # ==================== Begin backward ==================== + # Output Projection + out_g_t = 
fp8_transpose(out_g, transpose_output_2d=True) + + # We do not save an extra flash_x to save the memory usage + flash_x_t, flash_s = fp8_division_transpose( + flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True + ) + + if qargs.weight_memory_efficient: + weight2_t, weight2_s = fp8_division_transpose( + weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True + ) + fc2_g, attn_out_wg = fp8_linear_backward( + flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight2_t, weight2_s, group_size + ) + + return fp_grad, fc2_g, attn_out_wg, None, None, None, None, None, None, None, None + + +class CoatOLMoMLPResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 2 / 3 * Linear layers + (4) GELU / Silu Activation + """ + + def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, layer_id, hidden_size: int): + super().__init__(config, qargs, layer_id) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + self.ln_normalized_shape = config.d_model + self.act_output_multiplier = 0.5 if config.activation_type == ActivationType.swiglu else 1 + self.ff_proj = nn.Linear(config.d_model, hidden_size, bias=config.include_bias, device=config.init_device) + self.ff_out = nn.Linear( + int(self.act_output_multiplier * hidden_size), + config.d_model, + bias=config.include_bias, + device=config.init_device, + ) + self.training = True + + # below is only used when training = False + self.ff_norm = LayerNorm.build(config) + self.act = Activation.build(config) + assert (self.act.output_multiplier * hidden_size) % 1 == 0 + + def forward(self, re_x, x, s): + if self.training: + if self.qargs.weight_memory_efficient: # prepare for the weight + with torch.no_grad(): + weight1_s = self.prepare_weight(self.ff_proj.weight, "ff_proj", FP8Manager.is_first_microbatch) + weight2_s = self.prepare_weight(self.ff_out.weight, "ff_out", FP8Manager.is_first_microbatch) + + return _CoatOLMoMLPResidual.apply( + re_x, + x, + s, + self.ff_proj.weight, + None, + None, + weight1_s, + self.ff_out.weight, + None, + None, + weight2_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # prepare for the weight + with torch.no_grad(): + weight1, weight1_t, weight1_s = self.prepare_weight( + self.ff_proj.weight, "ff_proj", FP8Manager.is_first_microbatch + ) + weight2, weight2_t, weight2_s = self.prepare_weight( + self.ff_out.weight, "ff_out", FP8Manager.is_first_microbatch + ) + + return _CoatOLMoMLPResidual.apply( + re_x, + x, + s, + self.ff_proj.weight, + weight1, + weight1_t, + weight1_s, + self.ff_out.weight, + weight2, + weight2_t, + weight2_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + og_x = re_x + re_x = self.ff_norm(re_x) + re_x = self.ff_proj(re_x) + re_x = self.act(re_x) + re_x = self.ff_out(re_x) + re_x = og_x + re_x + return re_x, None, None + + +class _CoatOLMoMLPResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, + re_x, + in_x, + in_s, + weight1_origin, + weight1, + weight1_t, + weight1_s, + weight2_origin, + weight2, + weight2_t, + weight2_s, + group_size, + fwobits, + layer_id, + config, + qargs, + eps=1e-5, + ): + # For autograd + if fwobits["fabit"] == "E4M3": + # in_x = 
in_x.to(torch.float8_e4m3fn) + in_x = in_x.view(torch.float8_e4m3fn) + else: + raise ValueError("fabit should be E4M3") + + # LayerNorm + ln_x, ln_s, ln_x_t, ln_utils = fp8_layernorm_noparam_forward( + in_x, in_s, group_size, eps, transpose_output_2d=True + ) + + # Linear Layer of Up Projection and Gate Projection. They are fused as one linear layer. + if qargs.weight_memory_efficient: + assert weight1 is None # memory efficient + weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s) + fc1_x, fc1_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size) + + # NOTE: Becareful of the order + up_x, gate_x = fc1_x.chunk(2, dim=-1) + up_s, gate_s = fc1_s.chunk(2, dim=-1) + + # silu Activation + silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size) + + # Element-wise Multiplication + mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True) + + # Output Projection + if weight2 is None: # memory efficient + weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s) + fc2_x = fp8_linear_forward(mul_x, mul_s, weight2, weight2_s, False, group_size) + + # Add the activation together + fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc2_x, mul_x.dtype, group_size) + + # ==================== save for backward ==================== + ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s) + + ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s) + if ( + qargs.weight_memory_efficient + ): # Weight_1/2_origin will not be saved twice, so it will be more memory efficient. + assert weight1_t is None + ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s) + else: # Weight1/2_t is different from the origin weight, so saving it will consumes additional memory footprint. 
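+            # Cache the already transposed FP8 weights and their scales so the backward
+            # pass can use them directly instead of re-running fp8_division_transpose.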
+ ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s) + + ctx.group_size = group_size + ctx.ln_utils = ln_utils + ctx.utils = fwobits, layer_id, config, qargs + + out_x = out_x.view(torch.float8_e4m3fn) + + return fp_x, out_x, out_s + + @staticmethod + def backward(ctx, fp_grad, out_g, out_gs): + fwobits, layer_id, config, qargs = ctx.utils + + in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s = ctx.saved_tensors + + (weight1_t, weight1_s, weight2_t, weight2_s) = ctx.weight + group_size = ctx.group_size + mean, rstd, num_warps = ctx.ln_utils + fwobits, layer_id, config, qargs = ctx.utils + + # For autograd + if fwobits["babit"] == "E5M2": + # out_g = out_g.to(torch.float8_e5m2) + out_g = out_g.view(torch.float8_e5m2) + else: + raise ValueError("babit should be E5M2") + out_gs_max = out_gs.max() + + # ==================== Begin backward ==================== + # Output Projection + out_gs = out_gs.max() + out_g_t = fp8_transpose(out_g, transpose_output_2d=True) + + if qargs.weight_memory_efficient: + weight2_t, weight2_s = fp8_division_transpose( + weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True + ) + fc2_g, weight2_grad = fp8_linear_backward( + mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight2_t, weight2_s, group_size + ) + + # [MEM TEST] + del out_g, out_g_t, weight2_t + + # Element-wise Multiplication, 1 means gate, 2 means up + mul_g1, (mul_g2, mul_gs2) = fp8_mul_backward(silu_x, silu_s, up_x, up_s, fc2_g, group_size, fwobits["babit"]) + + # Silu activation + silu_g, silu_gs = fp8_silu_backward(gate_x, gate_s, mul_g1, group_size, fwobits["babit"]) + + # Prepare the input of Linear Layer. NOTE: Becareful of the order + gateup_g = torch.cat([mul_g2, silu_g], dim=-1) + gateup_gs = torch.cat([mul_gs2, silu_gs]) + gateup_gs = torch.max(gateup_gs) + + gateup_g, gateup_gs, gateup_g_t = fp8_division_transpose( + gateup_g, group_size, fwobits["babit"], gateup_gs, stochastic=False + ) + + # Linear Layer of Up and Gate Projection + if qargs.weight_memory_efficient: + weight1_t, weight1_s = fp8_division_transpose( + weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True + ) + fc1_g, weight1_grad = fp8_linear_backward( + ln_x_t, ln_s, gateup_g, gateup_gs, gateup_g_t, weight1_t, weight1_s, group_size + ) + + # layerNorm + in_g = fp8_layernorm_noparam_backward(in_x, in_s, fc1_g, group_size, mean, rstd, num_warps) + + # Add the gradient together + re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt( + fp_grad, in_g, group_size, fwobits["babit"], stochastic=False + ) + + in_g = in_g.view(torch.float8_e4m3fn) + + return ( + re_g, + in_g, + in_sg_g16, + weight1_grad, + None, + None, + None, + weight2_grad, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +class CoatOLMoBlock(nn.Module): + """ + A base class for transformer block implementations. + """ + + def __init__(self, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache): + super().__init__() + self.layer_id = layer_id + self.config = config + self.qargs = qargs + self.hidden_size = ( + config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model + ) + self.__cache = cache + assert config.d_model % config.n_heads == 0 + + self._activation_checkpoint_fn: Callable | None = None + + # Dropout. + self.dropout = Dropout(config.residual_dropout) + + # Layer norms. 
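+        # QK-norm is optional: `k_norm`/`q_norm` are only built when
+        # `config.attention_layer_norm` is set, and `k_norm` is sized by the effective
+        # number of KV heads so grouped/multi-query layouts are covered as well.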
+ self.k_norm: LayerNormBase | None = None + self.q_norm: LayerNormBase | None = None + if config.attention_layer_norm: + assert config.effective_n_kv_heads is not None + self.k_norm = LayerNormBase.build( + config, + size=(config.d_model // config.n_heads) * config.effective_n_kv_heads, + elementwise_affine=config.attention_layer_norm_with_affine, + ) + self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine) + + # Make sure QKV clip coefficient is positive, otherwise it's not well-defined. + if config.clip_qkv is not None: + assert config.clip_qkv > 0 + + # Activation function. + self.act = Activation.build(config) + assert (self.act.output_multiplier * self.hidden_size) % 1 == 0 + + if not self.qargs.use_quantize_model: + # Attention output projection. + self.attn_out = nn.Linear( + config.d_model, config.d_model, bias=config.include_bias, device=config.init_device + ) + + # Feed-forward output projection. + self.ff_out = nn.Linear( + int(self.act.output_multiplier * self.hidden_size), + config.d_model, + bias=config.include_bias, + device=config.init_device, + ) + self.ff_out._is_residual = True # type: ignore + + # Rotary embeddings. + if self.config.rope: + self.rotary_emb = RotaryEmbedding(config, self.__cache) + + self.flash_attn_func = None + self.flash_attn_varlen_func = None + if config.flash_attention: + try: + from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore + + self.flash_attn_func = flash_attn_func + self.flash_attn_varlen_func = flash_attn_varlen_func + except ModuleNotFoundError: + pass + + def reset_parameters(self): + if self.k_norm is not None: + self.k_norm.reset_parameters() + if self.q_norm is not None: + self.q_norm.reset_parameters() + + if not self.qargs.use_quantize_model: + if self.config.init_fn == InitFnType.normal: + attn_out_std = ff_out_std = self.config.init_std + cutoff_factor = self.config.init_cutoff_factor + + elif self.config.init_fn == InitFnType.mitchell: + attn_out_std = 1 / (math.sqrt(2 * self.config.d_model * (self.layer_id + 1))) + ff_out_std = 1 / (math.sqrt(2 * self.ff_out.in_features * (self.layer_id + 1))) + cutoff_factor = self.config.init_cutoff_factor or 3.0 + + elif self.config.init_fn == InitFnType.full_megatron: + attn_out_std = ff_out_std = self.config.init_std / math.sqrt(2.0 * self.config.n_layers) + cutoff_factor = self.config.init_cutoff_factor or 3.0 + + else: + raise NotImplementedError(self.config.init_fn) + + init_normal(self.attn_out, std=attn_out_std, init_cutoff_factor=cutoff_factor) + init_normal(self.ff_out, std=ff_out_std, init_cutoff_factor=cutoff_factor) + + def set_activation_checkpointing( + self, strategy: ActivationCheckpointingStrategy | None, checkpoint_func: Callable | None = None + ): + if strategy == ActivationCheckpointingStrategy.fine_grained: + self._activation_checkpoint_fn = checkpoint_func or activation_checkpoint_function(self.config) + else: + self._activation_checkpoint_fn = None + + @classmethod + def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor: + target_dtype = input_dtype + # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function + # `is_autocast_cpu_enabled()` for CPU autocast. + # See https://github.com/pytorch/pytorch/issues/110966. 
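+        # In short: when autocast is active on the bias' device, cast the bias to the
+        # autocast dtype before it is added to the attention scores; `ensure_finite_`
+        # below then guards against the -infs such a downcast can produce.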
+ if bias.device.type == "cuda" and torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled(): + target_dtype = torch.get_autocast_cpu_dtype() + if bias.dtype != target_dtype: + bias = bias.to(target_dtype) + ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False) + return bias + + def _scaled_dot_product_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor | None = None, + dropout_p: float = 0.0, + is_causal: bool = False, + max_doc_len: int | None = None, + cu_doc_lens: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Computes scaled dot product attention on query, key and value tensors, using an optional + attention mask if passed, and applying dropout if a probability greater than 0.0 is specified. + """ + if max_doc_len is not None and cu_doc_lens is not None: + assert self.flash_attn_varlen_func is not None, "flash-attn is required for document masking" + assert attn_mask is None, "attn-mask is currently not supported with document masking" + B, T, D = q.size(0), q.size(2), q.size(3) + r = self.flash_attn_varlen_func( + q.transpose(1, 2).view(B * T, -1, D), + k.transpose(1, 2).view(B * T, -1, D), + v.transpose(1, 2).view(B * T, -1, D), + cu_doc_lens, + cu_doc_lens, + max_doc_len, + max_doc_len, + dropout_p=dropout_p, + causal=is_causal, + ) + return r.view(B, T, -1, D).transpose(1, 2) + elif self.flash_attn_func is not None and attn_mask is None: + r = self.flash_attn_func( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=is_causal + ) + return r.transpose(1, 2) + else: + # torch's sdpa doesn't support GQA, so we're doing this + assert k.size(1) == v.size(1) + num_kv_heads = k.size(1) + num_q_heads = q.size(1) + if num_q_heads != num_kv_heads: + assert num_q_heads % num_kv_heads == 0 + k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads) + v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads) + + return F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + ) + + def attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attention_bias: torch.Tensor | None = None, + layer_past: tuple[torch.Tensor, torch.Tensor] | None = None, + use_cache: bool = False, + max_doc_len: int | None = None, + cu_doc_lens: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]: + B, T, C = q.size() # batch size, sequence length, d_model + dtype = k.dtype + + # Optionally apply layer norm to keys and queries. + if self.q_norm is not None and self.k_norm is not None: + q = self.q_norm(q).to(dtype=dtype) + k = self.k_norm(k).to(dtype=dtype) + + # Move head forward to be next to the batch dim. 
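+        # Queries keep `n_heads` heads while keys/values use `effective_n_kv_heads`,
+        # so the reshapes below also handle multi-query / grouped-query attention.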
+ # shape: (B, nh, T, hs) + q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2) + # shape: (B, n_kv_h, T, hs) + k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2) + # shape: (B, n_kv_h, T, hs) + v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2) + + if layer_past is not None: + past_key, past_value = layer_past + k = torch.cat((past_key, k), dim=-2) + v = torch.cat((past_value, v), dim=-2) + + present = (k, v) if use_cache else None + query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None + + if self.config.rope: + # Apply rotary embeddings. + q, k = self.rotary_emb(q, k) + + if attention_bias is not None: + # Resize and cast attention bias. + # The current dtype of the attention bias might not match the dtype that the SDP attn function will + # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding + # as down-casting the attention bias to the autocast precision will result in -infs, which will + # cause the SDP attn function to produce NaNs. + attention_bias = self._cast_attn_bias(attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype) + + # Get the attention scores. + # shape: (B, nh, T, hs) + att = self._scaled_dot_product_attention( + q, + k, + v, + attn_mask=attention_bias, + dropout_p=0.0 if not self.training else self.config.attention_dropout, + is_causal=attention_bias is None, + max_doc_len=max_doc_len, + cu_doc_lens=cu_doc_lens, + ) + + # Re-assemble all head outputs side-by-side. + att = att.transpose(1, 2).contiguous().view(B, T, C) + + # Apply output projection. NOTE: We move the attn output outside of this attention function + return att, present + + @abstractmethod + def forward( + self, + x: torch.Tensor, + attention_bias: torch.FloatTensor | None = None, + layer_past: tuple[torch.Tensor, torch.Tensor] | None = None, + use_cache: bool = False, + max_doc_len: int | None = None, + cu_doc_lens: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]: + raise NotImplementedError + + @classmethod + def build(cls, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache) -> OLMoBlock: + if config.block_type == BlockType.sequential: + return CoatOLMoSequentialBlock(layer_id, config, qargs, cache) + elif config.block_type == BlockType.llama: + return CoatOLMoLlamaBlock(layer_id, config, qargs, cache) + else: + raise NotImplementedError(f"Unknown block type: '{config.block_type}'") + + +class CoatOLMoSequentialBlock(CoatOLMoBlock): + """ + This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). To compute it as ``LN(MLP(x + LN(Attention(x))))``, + use the flag `norm_after`. + """ + + def __init__(self, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache): + super().__init__(layer_id, config, qargs, cache) + # Attention input projection. 
Projects x -> (q, k, v) + + assert not self.config.norm_after, "COAT currently does not support PostNorm" + + head_dim = config.d_model // config.n_heads + self.fused_dims = ( + config.d_model, + config.effective_n_kv_heads * head_dim, + config.effective_n_kv_heads * head_dim, + ) + + if self.qargs.use_quantize_model: + self.BeforeAttention = CoatOLMoBeforeAttentionResidual(config, qargs, self.layer_id, self.fused_dims) + self.AfterAttention = CoatOLMoAfterAttentionResidual(config, qargs, self.layer_id) + self.MLPResidual = CoatOLMoMLPResidual(config, qargs, self.layer_id, self.hidden_size) + else: + self.att_proj = nn.Linear( + config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device + ) + # Feed-forward input projection. + self.ff_proj = nn.Linear( + config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device + ) + + # Layer norms. + self.attn_norm = LayerNorm.build(config, size=config.d_model) + self.ff_norm = LayerNorm.build(config, size=config.d_model) + + def reset_parameters(self): + super().reset_parameters() + self.attn_norm.reset_parameters() + self.ff_norm.reset_parameters() + # NOTE: the standard deviation for these weights does not depend on the layer. + + if self.qargs.use_quantize_model: # The initialization appears here, not in CoatOLMoBlock's reset_parameters + if self.config.init_fn == InitFnType.normal: + attn_out_std = ff_out_std = self.config.init_std + cutoff_factor = self.config.init_cutoff_factor + + elif self.config.init_fn == InitFnType.mitchell: + attn_out_std = 1 / (math.sqrt(2 * self.config.d_model * (self.layer_id + 1))) + ff_out_std = 1 / (math.sqrt(2 * self.MLPResidual.ff_out.in_features * (self.layer_id + 1))) + cutoff_factor = self.config.init_cutoff_factor or 3.0 + + elif self.config.init_fn == InitFnType.full_megatron: + attn_out_std = ff_out_std = self.config.init_std / math.sqrt(2.0 * self.config.n_layers) + cutoff_factor = self.config.init_cutoff_factor or 3.0 + + else: + raise NotImplementedError(self.config.init_fn) + + init_normal(self.AfterAttention.attn_out, std=attn_out_std, init_cutoff_factor=cutoff_factor) + init_normal(self.MLPResidual.ff_out, std=ff_out_std, init_cutoff_factor=cutoff_factor) + + if self.config.init_fn == InitFnType.normal: + std = self.config.init_std + cutoff_factor = self.config.init_cutoff_factor + elif self.config.init_fn == InitFnType.mitchell: + std = 1 / math.sqrt(self.config.d_model) + cutoff_factor = self.config.init_cutoff_factor or 3.0 + elif self.config.init_fn == InitFnType.full_megatron: + std = self.config.init_std + cutoff_factor = self.config.init_cutoff_factor or 3.0 + else: + raise NotImplementedError(self.config.init_fn) + + if not self.qargs.use_quantize_model: + init_normal(self.att_proj, std, cutoff_factor) + init_normal(self.ff_proj, std, cutoff_factor) + else: + init_normal(self.BeforeAttention.att_proj, std, cutoff_factor) + init_normal(self.MLPResidual.ff_proj, std, cutoff_factor) + + def forward( + self, + x: torch.Tensor, + qx: torch.Tensor, + sx: torch.Tensor, + attention_bias: torch.Tensor | None = None, + layer_past: tuple[torch.Tensor, torch.Tensor] | None = None, + use_cache: bool = False, + max_doc_len: int | None = None, + cu_doc_lens: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]: + # Get query, key, value projections. 
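+        # In the quantized path, the pre-attention norm and the fused QKV projection are
+        # folded into `BeforeAttention`, which carries the full-precision residual `x`
+        # alongside the quantized activation/scale pair (qx, sx) threaded through the blocks.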
+ # shape: + # - for regular attn q, k, v: (batch_size, seq_len, d_model) + # - for multi-query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_heads) + # - for group query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_kv_heads) + + # import IPython + # IPython.embed() + + if self.qargs.use_quantize_model: + # if False: + x, qkv = self.BeforeAttention(x, qx, sx) + else: + # apply norm before + h = self.attn_norm(x) + + qkv = self.BeforeAttention.att_proj(h) + + if self.config.clip_qkv is not None: + qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + q, k, v = qkv.split(self.fused_dims, dim=-1) + + # Get attention scores. + att, cache = self.attention( + q, + k, + v, + attention_bias, + layer_past=layer_past, + use_cache=use_cache, + max_doc_len=max_doc_len, + cu_doc_lens=cu_doc_lens, + ) + + # import IPython + # IPython.embed() + if self.qargs.use_quantize_model: + # if False: + x, qx, sx = self.AfterAttention(x, att) + else: + att = self.AfterAttention.attn_out(att) + + # Add attention scores. + # shape: (B, T, C) + x = x + self.dropout(att) + + if self.qargs.use_quantize_model: + # if False: + x, qx, sx = self.MLPResidual(x, qx, sx) + else: + # Add feed-forward projection. + # shape: (batch_size, seq_len, d_model) + og_x = x + + x = self.ff_norm(x) + + x = self.MLPResidual.ff_proj(x) + + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.act, x) # type: ignore + else: + x = self.act(x) + x = self.MLPResidual.ff_out(x) + + x = self.dropout(x) + x = og_x + x + + # import IPython + # IPython.embed() + + return x, qx, sx, cache + + +class CoatOLMoLlamaBlock(OLMoBlock): + """ + This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). This block is similar to `OLMoSequentialBlock` + but some operations have slightly different implementations to imitate the + behavior of Llama. + """ + + def __init__(self, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache): + super().__init__(layer_id, config, qargs, cache) + # Layer norms. + self.attn_norm = LayerNorm.build(config) + self.ff_norm = LayerNorm.build(config) + self.__cache = cache + + # Attention input projection. Projects x -> (q, k, v) + if config.multi_query_attention: + q_proj_out_dim = config.d_model + k_proj_out_dim = config.d_model // config.n_heads + v_proj_out_dim = config.d_model // config.n_heads + else: + q_proj_out_dim = config.d_model + k_proj_out_dim = config.d_model + v_proj_out_dim = config.d_model + self.q_proj = nn.Linear(config.d_model, q_proj_out_dim, bias=config.include_bias, device=config.init_device) + self.k_proj = nn.Linear(config.d_model, k_proj_out_dim, bias=config.include_bias, device=config.init_device) + self.v_proj = nn.Linear(config.d_model, v_proj_out_dim, bias=config.include_bias, device=config.init_device) + + # Feed-forward input projection. + self.ff_proj = nn.Linear(config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device) + + def reset_parameters(self): + super().reset_parameters() + self.attn_norm.reset_parameters() + self.ff_norm.reset_parameters() + # NOTE: the standard deviation for these weights does not depend on the layer. 
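+        # (e.g. `mitchell` init uses 1 / sqrt(d_model) here, with no (layer_id + 1) term.)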
+ + if self.config.init_fn == InitFnType.normal: + std = self.config.init_std + cutoff_factor = self.config.init_cutoff_factor + elif self.config.init_fn == InitFnType.mitchell: + std = 1 / math.sqrt(self.config.d_model) + cutoff_factor = self.config.init_cutoff_factor or 3.0 + elif self.config.init_fn == InitFnType.full_megatron: + std = self.config.init_std + cutoff_factor = self.config.init_cutoff_factor or 3.0 + else: + raise NotImplementedError(self.config.init_fn) + + init_normal(self.q_proj, std, cutoff_factor) + init_normal(self.k_proj, std, cutoff_factor) + init_normal(self.v_proj, std, cutoff_factor) + init_normal(self.ff_proj, std, cutoff_factor) + + def _scaled_dot_product_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: torch.Tensor | None = None, + dropout_p: float = 0.0, + is_causal: bool = False, + max_doc_len: int | None = None, + cu_doc_lens: torch.Tensor | None = None, + ) -> torch.Tensor: + if max_doc_len is not None or cu_doc_lens is not None: + raise NotImplementedError(f"attention document masking is not implemented for {self.__class__.__name__}") + + attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1)) + + if is_causal: + assert attn_mask is None + + query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None + attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len] + elif attn_mask is not None: + attn_bias = attn_mask.to(q.dtype) + else: + attn_bias = torch.zeros_like(attn_weights) + + attn_weights += attn_bias + attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout_p) + return torch.matmul(attn_weights, v) + + def forward( + self, + x: torch.Tensor, + qx: torch.Tensor, + sx: torch.Tensor, + attention_bias: torch.Tensor | None = None, + layer_past: tuple[torch.Tensor, torch.Tensor] | None = None, + use_cache: bool = False, + max_doc_len: int | None = None, + cu_doc_lens: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]: + # Get query, key, value projections. + # shape: + # - for regular attn q, k, v: (batch_size, seq_len, d_model) + # - for multi-query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_heads) + x_normed = self.attn_norm(x) + q = self.q_proj(x_normed) + k = self.k_proj(x_normed) + v = self.v_proj(x_normed) + + if self.config.clip_qkv is not None: + q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + # Get attention scores. + att, cache = self.attention( + q, + k, + v, + attention_bias, + layer_past=layer_past, + use_cache=use_cache, + max_doc_len=max_doc_len, + cu_doc_lens=cu_doc_lens, + ) + + att = self.attn_out(att) # NOTE: we move the attn_out outside the self.attention module + + # Add attention scores. + # shape: (B, T, C) + x = x + self.dropout(att) + + # Add feed-forward projection. 
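+        # The feed-forward norm and activation below are optionally wrapped by the
+        # fine-grained activation-checkpointing function when that strategy is enabled.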
+ # shape: (batch_size, seq_len, d_model) + og_x = x + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore + else: + x = self.ff_norm(x) + x = self.ff_proj(x) + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.act, x) # type: ignore + else: + x = self.act(x) + x = self.ff_out(x) + x = self.dropout(x) + x = og_x + x + + return x, cache + + +class CoatOLMoBlockGroup(nn.ModuleList): + def __init__(self, config: ModelConfig, layer_offset: int, modules: Iterable[nn.Module] | None = None): + super().__init__(modules) + self.config = config + self.layer_offset = layer_offset + self.activation_checkpointing_strategy: ActivationCheckpointingStrategy | None = None + self._activation_checkpoint_fn = activation_checkpoint_function(self.config) + + def forward( + self, + x: torch.Tensor, + attention_bias: torch.FloatTensor | None = None, + layers_past: list[tuple[torch.Tensor, torch.Tensor]] | None = None, + use_cache: bool = False, + max_doc_len: int | None = None, + cu_doc_lens: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]] | None]: + attn_key_values: list[tuple[torch.Tensor, torch.Tensor]] | None = [] if use_cache else None + for block_idx, block in enumerate(self): + layer_past = None if layers_past is None else layers_past[block_idx] + block_idx += self.layer_offset + if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx): + # shape: (batch_size, seq_len, d_model) + x, cache = self._activation_checkpoint_fn( # type: ignore + block, + x, + attention_bias=attention_bias, + layer_past=layer_past, + use_cache=use_cache, + max_doc_len=max_doc_len, + cu_doc_lens=cu_doc_lens, + ) + else: + # shape: (batch_size, seq_len, d_model) + x, cache = block( + x, + attention_bias=attention_bias, + layer_past=layer_past, + use_cache=use_cache, + max_doc_len=max_doc_len, + cu_doc_lens=cu_doc_lens, + ) + if attn_key_values is not None: + assert cache is not None + attn_key_values.append(cache) + return x, attn_key_values + + def reset_parameters(self): + for block in self: + block.reset_parameters() + + def set_activation_checkpointing( + self, strategy: ActivationCheckpointingStrategy | None, checkpoint_func: Callable | None = None + ): + self.activation_checkpointing_strategy = strategy + for block in self: + block.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func) + + +class CoatOLMo(nn.Module): + def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, init_params: bool = True): + super().__init__() + self.config = config + self.qargs = qargs + self.__cache = BufferCache() + + # Validate config. + if self.config.alibi and self.config.flash_attention: + raise OLMoConfigurationError("ALiBi is currently not supported with FlashAttention") + + if self.config.alibi and self.config.rope: + raise OLMoConfigurationError("ALiBi and RoPE are mutually exclusive") + + if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size: + if self.config.embedding_size < self.config.vocab_size: + raise OLMoConfigurationError("embedding size should be at least as big as vocab size") + elif self.config.embedding_size % 128 != 0: + import warnings + + warnings.warn( + "Embedding size is not a multiple of 128! 
This could hurt throughput performance.", UserWarning + ) + + self.activation_checkpointing_strategy: ActivationCheckpointingStrategy | None = None + self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config) + + if not ( + 0 < self.config.block_group_size <= self.config.n_layers + and self.config.n_layers % self.config.block_group_size == 0 + ): + raise OLMoConfigurationError("n layers must be divisible by block group size") + + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it + + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.embedding_size or config.vocab_size, config.d_model, device=config.init_device), + emb_drop=Dropout(config.embedding_dropout), + ln_f=LayerNorm.build(config), + ) + ) + + blocks = [CoatOLMoBlock.build(i, config, qargs, self.__cache) for i in range(config.n_layers)] + if self.config.block_group_size > 1: + block_groups = [ + CoatOLMoBlockGroup(config, i, blocks[i : i + config.block_group_size]) + for i in range(0, config.n_layers, config.block_group_size) + ] + self.transformer.update({"block_groups": nn.ModuleList(block_groups)}) + else: + self.transformer.update({"blocks": nn.ModuleList(blocks)}) + + if not (self.config.alibi or self.config.rope): + self.transformer.update( + {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)} + ) + if not config.weight_tying: + self.transformer.update( + { + "ff_out": nn.Linear( + config.d_model, + config.embedding_size or config.vocab_size, + bias=config.include_bias, + device=config.init_device, + ) + } + ) + if config.embedding_layer_norm: + self.transformer.update({"emb_norm": LayerNorm.build(config)}) + + # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights. + if init_params and self.config.init_device != "meta": + self.reset_parameters() + self.__num_fwd_flops: int | None = None + self.__num_bck_flops: int | None = None + + # Warm up cache. + if self.config.alibi: + get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config)) + self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config)) + + # Quantize + self.quantize_input_before_block = Coat_quantize_bgn(qargs) + self.quantize_output_after_block = Coat_quantize_end(qargs) + + set_activation_checkpointing = OLMo.set_activation_checkpointing + device = OLMo.device + reset_parameters = OLMo.reset_parameters + get_alibi_attention_bias = OLMo.get_alibi_attention_bias + + def forward( + self, + input_ids: torch.LongTensor, + input_embeddings: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + attention_bias: torch.Tensor | None = None, + past_key_values: Sequence[tuple[torch.Tensor, torch.Tensor]] | None = None, + use_cache: bool = False, + last_logits_only: bool = False, + output_hidden_states: bool | None = None, + doc_lens: torch.Tensor | None = None, + max_doc_lens: Sequence[int] | None = None, + ) -> OLMoOutput: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input + embeddings. When provided, it is treated as the output of the input embedding layer. + :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates + which input IDs are masked. A `1` value in the mask means that + the corresponding input ID should *not* be ignored. 
A `0` means + that the corresponding input ID is masked. + + This has the same meaning as the `attention_mask` in HuggingFace's `transformers` + library. + :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`, + `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used + to introduce causal or other biases. + + If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]` + indicates that the i-th element in the sequence is allowed to attend to the j-th + element in the sequence. + + If the tensor is a float tensor, it will just be added to the attention + scores before the softmax. + + The default is causal, which corresponds to a lower-diagonal byte matrix of ones. + :param past_key_values: Pre-computed keys and values for each attention block. + Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + :param use_cache: If `True`, return key and value tensors for each block. + :param last_logits_only: If `True`, only compute the logits for the last token of each sequence. + This can speed up decoding when you only care about the next token. + :param doc_lens: Document lengths to use in attention for intra-document masking. + Shape `(batch_size, max_docs)`. + :param max_doc_lens: Maximum document length for each instance in the batch. + """ + output_hidden_states = output_hidden_states if output_hidden_states is not None else False + + if past_key_values: + assert len(past_key_values) == self.config.n_layers + + batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2] + if past_key_values is None: + past_length = 0 + else: + past_length = past_key_values[0][0].size(-2) + + max_doc_len: int | None = None + cu_doc_lens: torch.Tensor | None = None + if doc_lens is not None and max_doc_lens is not None: + max_doc_len = max(max_doc_lens) + cu_doc_lens = get_cumulative_document_lengths(doc_lens) + + # Get embeddings of input. + # shape: (batch_size, seq_len, d_model) + x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore + + # Apply embedding layer norm. + if self.config.embedding_layer_norm: + x = self.transformer.emb_norm(x) + + if not (self.config.alibi or self.config.rope): + # Get positional embeddings. + # shape: (1, seq_len) + pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0) + # shape: (1, seq_len, d_model) + pos_emb = self.transformer.wpe(pos) # type: ignore + x = pos_emb + x + + # Apply dropout. + # shape: (batch_size, seq_len, d_model) + x = self.transformer.emb_drop(x) # type: ignore + + # Transform the attention mask into what the blocks expect. + if attention_mask is not None: + # shape: (batch_size, 1, 1, seq_len) + attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :] + attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min + + # Merge attention mask with attention bias. + if ( + attention_bias is not None + or attention_mask is not None + or self.config.alibi + # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly + # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute + # scores correctly. 
+ or past_key_values is not None + ): + if attention_bias is None and self.config.alibi: + attention_bias = get_causal_attention_bias( + self.__cache, past_length + seq_len, x.device + ) + self.get_alibi_attention_bias(past_length + seq_len, x.device) + elif attention_bias is None: + attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device) + elif attention_bias.dtype in (torch.int8, torch.bool): + attention_bias = attention_bias.to(dtype=torch.float) + attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min) + + # Transform to the right shape and data type. + mask_len = seq_len + if attention_mask is not None: + mask_len = attention_mask.shape[-1] + elif past_key_values is not None: + mask_len = past_key_values[0][0].shape[-2] + seq_len + attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float) + + # Add in the masking bias. + if attention_mask is not None: + attention_bias = attention_bias + attention_mask + # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf. + # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead + # it can produce NaNs. + ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False) + + attn_key_values: list[tuple[torch.Tensor, torch.Tensor]] | None = [] if use_cache else None + + # decoder layers + all_hidden_states = [] + + # Prepare the input for COAT decoderlayer + x, qx, sx = self.quantize_input_before_block(x) + + # Apply blocks one-by-one. + if self.config.block_group_size == 1: + for block_idx, block in enumerate(self.transformer.blocks): + if output_hidden_states: + # add hidden states + all_hidden_states.append(x) + + layer_past = None if past_key_values is None else past_key_values[block_idx] + if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx): + # shape: (batch_size, seq_len, d_model) + x, qx, sx, cache = self._activation_checkpoint_fn( + block, + x, + qx, + sx, + attention_bias=attention_bias, + layer_past=layer_past, + use_cache=use_cache, + max_doc_len=max_doc_len, + cu_doc_lens=cu_doc_lens, + ) + else: + # shape: (batch_size, seq_len, d_model) + x, qx, sx, cache = block( + x, + qx, + sx, + attention_bias=attention_bias, + layer_past=layer_past, + use_cache=use_cache, + max_doc_len=max_doc_len, + cu_doc_lens=cu_doc_lens, + ) + + if attn_key_values is not None: + assert cache is not None + attn_key_values.append(cache) + else: + for group_idx, block_group in enumerate(self.transformer.block_groups): + if output_hidden_states: + # add hidden states + all_hidden_states.append(x) + + layers_past = ( + None + if past_key_values is None + else past_key_values[ + group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size + ] + ) + x, cache = block_group( + x, + attention_bias=attention_bias, + layers_past=layers_past, + use_cache=use_cache, + max_doc_len=max_doc_len, + cu_doc_lens=cu_doc_lens, + ) + if attn_key_values is not None: + assert cache is not None + attn_key_values.extend(cache) + + # Summarize the output of the Decoder Layer + x = self.quantize_output_after_block(x, qx, sx) + + if last_logits_only: + # shape: (batch_size, 1, d_model) + x = x[:, -1, :].unsqueeze(1) + + # Apply final layer norm. 
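+        # (If `last_logits_only` was set, `x` has already been sliced to the final
+        # position, so the norm and LM head below run on a single token per sequence.)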
+ # shape: (batch_size, seq_len or 1, d_model) + x = self.transformer.ln_f(x) # type: ignore + if output_hidden_states: + # add final hidden state post-final-layernorm, following HuggingFace's convention + all_hidden_states.append(x) + + # Get logits. + # shape: (batch_size, seq_len or 1, vocab_size) + if self.config.weight_tying: + logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore + else: + logits = self.transformer.ff_out(x) # type: ignore + if self.config.scale_logits: + logits.mul_(1 / math.sqrt(self.config.d_model)) + + return OLMoOutput( + logits=logits, + attn_key_values=attn_key_values, + hidden_states=tuple(all_hidden_states) if output_hidden_states else None, + ) + + def get_fsdp_wrap_policy(self, wrap_strategy: FSDPWrapStrategy | None = None): + if wrap_strategy is None: + return None + + # The 'recurse' mode for the wrap function does not behave like you'd expect. + # Even if we return False, it may still recurse because PyTorch does what it wants, + # not what you want. This causes issues when, for example, we want to wrap 'ff_out' (a linear layer) + # but not other linear layers within a block. + # So we have to explicitly tell PyTorch which linear layers to wrap, and we also just + # return True in 'recurse' mode for simplicity. + size_based_module_to_wrap = {self.transformer.wte} + if hasattr(self.transformer, "ff_out"): + size_based_module_to_wrap.add(self.transformer.ff_out) + + if wrap_strategy == FSDPWrapStrategy.by_block: + + def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): + del nonwrapped_numel + wrap = isinstance(module, CoatOLMoBlock) + if recurse: + return True + else: + return wrap + + return fsdp_wrap_fn + elif wrap_strategy == FSDPWrapStrategy.by_block_and_size: + + def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): + del nonwrapped_numel + wrap = isinstance(module, (CoatOLMoBlock,)) or module in size_based_module_to_wrap + if recurse: + return True + else: + return wrap + + return fsdp_wrap_fn + elif wrap_strategy == FSDPWrapStrategy.by_block_group: + if self.config.block_group_size <= 1: + raise OLMoConfigurationError( + "'by_block_group' FSDP wrapping strategy requires block group size greater than 1" + ) + + def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): + del nonwrapped_numel + wrap = isinstance(module, CoatOLMoBlockGroup) + if recurse: + return True + else: + return wrap + + return fsdp_wrap_fn + elif wrap_strategy == FSDPWrapStrategy.by_block_group_and_size: + if self.config.block_group_size <= 1: + raise OLMoConfigurationError( + "'by_block_group_and_size' FSDP wrapping strategy requires block group size greater than 1" + ) + + def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): + del nonwrapped_numel + wrap = isinstance(module, (CoatOLMoBlockGroup,)) or module in size_based_module_to_wrap + if recurse: + return True + else: + return wrap + + return fsdp_wrap_fn + elif wrap_strategy == FSDPWrapStrategy.size_based: + from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy + + return size_based_auto_wrap_policy + elif wrap_strategy in { + FSDPWrapStrategy.one_in_two, + FSDPWrapStrategy.one_in_three, + FSDPWrapStrategy.one_in_four, + FSDPWrapStrategy.one_in_five, + }: + c = { + FSDPWrapStrategy.one_in_two: 2, + FSDPWrapStrategy.one_in_three: 3, + FSDPWrapStrategy.one_in_four: 4, + FSDPWrapStrategy.one_in_five: 5, + }[wrap_strategy] + + def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): + del 
nonwrapped_numel + wrap = isinstance(module, CoatOLMoBlock) and module.layer_id % c == 0 + if recurse: + return True + else: + return wrap + + return fsdp_wrap_fn + else: + raise NotImplementedError(wrap_strategy) + + num_params = OLMo.num_params + + @property + def num_fwd_flops(self): + if self.__num_fwd_flops: + return self.__num_fwd_flops + + # embedding table is just a lookup in the forward pass + n_params = self.num_params(include_embedding=False) + # the number of parameters is approximately the number of multiply-accumulates (MAC) in the network + # each MAC has 2 FLOPs - we multiply by 2 ie 2 * n_param + # this gets us FLOPs / token + params_flops_per_token = 2 * n_params + # there are 2 FLOPS per mac; there is A=Q*K^T and out=A*V ops (ie mult by 2) + attn_flops_per_token = self.config.n_layers * 2 * 2 * (self.config.d_model * self.config.max_sequence_length) + self.__num_fwd_flops = params_flops_per_token + attn_flops_per_token + return self.__num_fwd_flops + + @property + def num_bck_flops(self): + if self.__num_bck_flops: + return self.__num_bck_flops + + n_params = self.num_params() + params_flops_per_token = 4 * n_params + attn_flops_per_token = self.config.n_layers * 8 * (self.config.d_model * self.config.max_sequence_length) + self.__num_bck_flops = params_flops_per_token + attn_flops_per_token + return self.__num_bck_flops + + generate = OLMo.generate + + @classmethod + def from_checkpoint( + cls, checkpoint_dir: PathOrStr, device: str = "cpu", checkpoint_type: CheckpointType | None = None + ) -> CoatOLMo: + """ + Load an OLMo model from a checkpoint. + """ + from olmo.util import resource_path + + # Guess checkpoint type. + if checkpoint_type is None: + try: + if resource_path(checkpoint_dir, "model.pt").is_file(): + checkpoint_type = CheckpointType.unsharded + else: + checkpoint_type = CheckpointType.sharded + except FileNotFoundError: + checkpoint_type = CheckpointType.sharded + + # Load config. + config_path = resource_path(checkpoint_dir, "config.yaml") + model_config = ModelConfig.load(config_path, key="model", validate_paths=False) + + if checkpoint_type == CheckpointType.unsharded: + # Initialize model (always on CPU to start with so we don't run out of GPU memory). + model_config.init_device = "cpu" + model = CoatOLMo(model_config) + + # Load state dict directly to target device. + state_dict_path = resource_path(checkpoint_dir, "model.pt") + state_dict = torch.load(state_dict_path, map_location="cpu") + model.load_state_dict(model._make_state_dict_compatible(state_dict)[0]) + model = model.to(torch.device(device)) + else: + train_config = TrainConfig.load(config_path) + if train_config.sharded_checkpointer == ShardedCheckpointerType.olmo_core: + from olmo_core.distributed.checkpoint import load_model_and_optim_state # type: ignore + + model_config.init_device = device + model = CoatOLMo(model_config) + load_model_and_optim_state(checkpoint_dir, model) + else: + # train_config.sharded_checkpointer == ShardedCheckpointerType.torch_new + from olmo.checkpoint import load_model_state + + # Initialize model on target device. In this case the state dict is loaded in-place + # so it's not necessary to start on CPU if the target device is a GPU. + model_config.init_device = device + model = CoatOLMo(model_config) + + # Load state dict in place. 
+ load_model_state(checkpoint_dir, model) + + return model.eval() + + def _make_state_dict_compatible( + self, state_dict: dict[str, torch.Tensor] + ) -> tuple[dict[str, torch.Tensor], dict[str, set[str]]]: + """ + Handles some cases where the state dict is valid yet may need to be transformed in order to + be loaded. + + This modifies the state dict in-place and also returns it, along with a mapping of original key + names to new key names in cases where the keys were simply renamed. That mapping can be used + to make a corresponding optimizer state dict compatible as well. + """ + import re + from fnmatch import fnmatch + + new_keys_to_og_keys: dict[str, str] = {} + + # Remove "_fsdp_wrapped_module." prefix from all keys. We don't want this prefix when the model is + # not wrapped in FSDP. And when the model is wrapped in FSDP, loading this state dict will still work + # fine without the prefixes. This also simplifies the other steps below. + for key in list(state_dict.keys()): + state_dict[(new_key := key.replace("_fsdp_wrapped_module.", ""))] = state_dict.pop(key) + new_keys_to_og_keys[new_key] = key + + # For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222 + if self.config.block_type == BlockType.sequential: + for key in list(state_dict.keys()): + if fnmatch(key, "transformer.*.norm.weight"): + tensor = state_dict.pop(key) + state_dict[(new_key := key.replace("norm.weight", "attn_norm.weight"))] = tensor + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + state_dict[(new_key := key.replace("norm.weight", "ff_norm.weight"))] = tensor.clone() + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + del new_keys_to_og_keys[key] + elif fnmatch(key, "transformer.*.norm.bias"): + tensor = state_dict.pop(key) + state_dict[(new_key := key.replace("norm.bias", "attn_norm.bias"))] = tensor + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + state_dict[(new_key := key.replace("norm.bias", "ff_norm.bias"))] = tensor.clone() + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + del new_keys_to_og_keys[key] + + # Realquantization will change the place the linear layers happen + if self.qargs.use_quantize_model == "coat_real": + for key in list(state_dict.keys()): + if fnmatch(key, "transformer.blocks.*.att_proj.weight") and "BeforeAttention" not in key: + tensor = state_dict.pop(key) + state_dict[(new_key := key.replace("att_proj.weight", "BeforeAttention.att_proj.weight"))] = tensor + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + del new_keys_to_og_keys[key] + elif fnmatch(key, "transformer.blocks.*.attn_out.weight") and "AfterAttention" not in key: + tensor = state_dict.pop(key) + state_dict[(new_key := key.replace("attn_out.weight", "AfterAttention.attn_out.weight"))] = tensor + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + del new_keys_to_og_keys[key] + elif fnmatch(key, "transformer.blocks.*.ff_proj.weight") and "MLPResidual" not in key: + tensor = state_dict.pop(key) + state_dict[(new_key := key.replace("ff_proj.weight", "MLPResidual.ff_proj.weight"))] = tensor + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + del new_keys_to_og_keys[key] + elif fnmatch(key, "transformer.blocks.*.ff_out.weight") and "MLPResidual" not in key: + tensor = state_dict.pop(key) + state_dict[(new_key := key.replace("ff_out.weight", "MLPResidual.ff_out.weight"))] = tensor + new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key] + del new_keys_to_og_keys[key] + + # For loading a state dict that was saved with a different 
`block_group_size`. + if "transformer.block_groups.0.0.attn_out.weight" in state_dict.keys(): + state_dict_block_group_size = len( + [k for k in state_dict.keys() if fnmatch(k, "transformer.block_groups.0.*.attn_out.weight")] + ) + else: + state_dict_block_group_size = 1 + if self.config.block_group_size != state_dict_block_group_size: + log.info( + f"Regrouping state dict blocks from group size {state_dict_block_group_size} to " + f"group size {self.config.block_group_size}" + ) + # For simplicity we're first going to flatten out the block groups in the state dict (if necessary) + # and then (re-)group them into the right block sizes. + if state_dict_block_group_size > 1: + for key in list(state_dict.keys()): + if (m := re.match(r"transformer.block_groups\.(\d+)\.(\d+)\..*", key)) is not None: + group_idx, group_block_idx = int(m.group(1)), int(m.group(2)) + block_idx = (group_idx * state_dict_block_group_size) + group_block_idx + state_dict[ + ( + new_key := key.replace( + f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}." + ) + ) + ] = state_dict.pop(key) + new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key) + + if self.config.block_group_size > 1: + # Group the state dict blocks into the right block size. + for key in list(state_dict.keys()): + if (m := re.match(r"transformer.blocks\.(\d+)\..*", key)) is not None: + block_idx = int(m.group(1)) + group_idx, group_block_idx = ( + block_idx // self.config.block_group_size, + block_idx % self.config.block_group_size, + ) + state_dict[ + ( + new_key := key.replace( + f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}." + ) + ) + ] = state_dict.pop(key) + new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key) + + og_keys_to_new: dict[str, set[str]] = defaultdict(set) + for new_key, og_key in new_keys_to_og_keys.items(): + og_keys_to_new[og_key].add(new_key) + + return state_dict, og_keys_to_new diff --git a/llava/model/coat/activation/real_quantization/__init__.py b/llava/model/coat/activation/real_quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..61875efb5d6a495fdab36438ea99a8784a88c4cf --- /dev/null +++ b/llava/model/coat/activation/real_quantization/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
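
Before moving on to the quantization kernels, here is a small, self-contained sketch of the block regrouping performed in `_make_state_dict_compatible` above. The helper names are hypothetical and the key strings illustrative; it only demonstrates the `blocks.{i}` ↔ `block_groups.{g}.{j}` index arithmetic, not actual checkpoint loading.

```python
# Hypothetical helpers illustrating the key renaming in _make_state_dict_compatible.
# No model weights are touched; this is only the index arithmetic.

def group_key(key: str, block_idx: int, block_group_size: int) -> str:
    """transformer.blocks.{i}.* -> transformer.block_groups.{g}.{j}.*"""
    group_idx, group_block_idx = divmod(block_idx, block_group_size)
    return key.replace(f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}.")

def ungroup_key(key: str, group_idx: int, group_block_idx: int, block_group_size: int) -> str:
    """transformer.block_groups.{g}.{j}.* -> transformer.blocks.{i}.*"""
    block_idx = group_idx * block_group_size + group_block_idx
    return key.replace(f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}.")

# With block_group_size=2, flat block 5 lands in group 2, slot 1 (5 = 2 * 2 + 1).
assert group_key("transformer.blocks.5.attn_out.weight", 5, 2) == \
    "transformer.block_groups.2.1.attn_out.weight"
assert ungroup_key("transformer.block_groups.2.1.attn_out.weight", 2, 1, 2) == \
    "transformer.blocks.5.attn_out.weight"
```
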
+ +# Activation +# Utils +from ._dequantize import fp8_dequantize +from ._division import fp8_division +from ._division_transpose import fp8_division_transpose +from ._quantize import fp8_quantize +from ._quantize_pertensor import fp8_quantize_pertensor +from ._quantize_pertensor_transpose import fp8_quantize_pertensor_transpose +from ._transpose import fp8_transpose +from .add_bwd import fp8_add_Ifp_Ifp_Ofp_Opt +from .add_fwd import fp8_add_Ifp_Ifp_Ofp_Og16 + +# Normalization +from .func_layernorm_noparam import fp8_layernorm_noparam_backward, fp8_layernorm_noparam_forward +from .func_quantize import Coat_quantize_bgn, Coat_quantize_end +from .func_rmsnorm import fp8_rmsnorm_backward, fp8_rmsnorm_forward +from .gelu_bwd import fp8_gelu_backward +from .gelu_fwd import fp8_gelu_forward + +# linear and add +from .linear import fp8_linear_backward, fp8_linear_forward +from .mul_bwd import fp8_mul_backward +from .mul_fwd import fp8_mul_forward +from .silu_bwd import fp8_silu_backward +from .silu_fwd import fp8_silu_forward diff --git a/llava/model/coat/activation/real_quantization/_dequantize.py b/llava/model/coat/activation/real_quantization/_dequantize.py new file mode 100644 index 0000000000000000000000000000000000000000..4347932f9c3691ffada9d59258337aac75b74942 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/_dequantize.py @@ -0,0 +1,162 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block + +"""Quantize Operator""" +"""Input uses 1 * 16 group quantization""" +"""Output uses 1 * 16 group quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_dequantize_kernel( + output_ptr, # output + input_ptr, + input_scale_ptr, # input + M, + N, + SN, + QB: tl.constexpr, # shape + input_stride_0, + input_stride_1, # input stride + s_input_stride_0, + s_input_stride_1, # scale of output stride + output_stride_0, + output_stride_1, # output stride + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + # input ptr + scale_input_ptr = tl.make_block_ptr( + base=input_scale_ptr, + shape=(M, SN), + strides=(s_input_stride_0, s_input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + scale_input = tl.load(scale_input_ptr) + + input = input.to(tl.float32) + scale_input = scale_input.to(tl.float32) + + # Dequantize and gelu calculation + scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1)) + input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + output = input * scale_input + + output = tl.reshape(output, (BLOCK_M, BLOCK_N)) + output = output.to(output_ptr.dtype.element_ty) + + # debug + # gelu_output = input + # scale_output = scale_input + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + tl.store(output_block_ptr, output, boundary_check=(0, 1)) + + +def fp8_dequantize(x, s_x, QB): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + SN = N // QB + + y = torch.empty_like(x, dtype=torch.bfloat16) + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_dequantize_kernel[grid]( + y, + x, + s_x, + M, + N, + SN, + QB, + x.stride(0), + x.stride(1), + s_x.stride(0), + s_x.stride(1), + y.stride(0), + y.stride(1), + ) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + + return y diff --git a/llava/model/coat/activation/real_quantization/_division.py b/llava/model/coat/activation/real_quantization/_division.py new file mode 100644 index 0000000000000000000000000000000000000000..335ff1a67795dc858a740c6618466ff79fb554cb --- /dev/null +++ b/llava/model/coat/activation/real_quantization/_division.py @@ -0,0 +1,212 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. 
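
As a quick sanity check on the group quantize/dequantize pair, the sketch below round-trips a random activation through `fp8_quantize` and `fp8_dequantize`. It assumes the repo root is on `PYTHONPATH` so the package imports as `llava.model.coat.activation.real_quantization`, and it needs a CUDA GPU with Triton FP8 support; the shapes and group size are illustrative.

```python
# Round-trip sketch for the 1x16 group quantization kernels (illustrative shapes).
import torch

from llava.model.coat.activation.real_quantization import fp8_dequantize, fp8_quantize

QB = 16  # quantization group size (1 x 16)
x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")

qx, sx = fp8_quantize(x, QB, "E4M3")   # qx: float8_e4m3fn, sx: (128, 256 // QB) group scales
x_hat = fp8_dequantize(qx, sx, QB)     # back to bfloat16

rel_err = (x.float() - x_hat.float()).abs().mean() / x.float().abs().mean()
print(f"mean relative error after E4M3 round trip: {rel_err.item():.4f}")
```
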
+# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block + +"""Quantize and Transpose Operator""" +"""Input uses 1 * 16 group quantization""" +"""Output uses 1 * 16 group quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_division_kernel( + output_ptr, # output + input_ptr, + input_scale_ptr, # input + noise_ptr, # noise for stochastic + M, + N, + SN, + QB: tl.constexpr, + fp8_max, + e_bit: tl.constexpr, + m_bit: tl.constexpr, # shape + input_stride_0, + input_stride_1, # input stride + output_stride_0, + output_stride_1, # output stride + SCALE_MIN_THRES: tl.constexpr, # We do not use it since we believe SCALE_MIN_THRES should be used in previous kernel when calculating scaling factor + STOCHASTIC: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + input = input.to(tl.float32) + scale_output = tl.load(input_scale_ptr) + scale_output = scale_output.to(tl.float32) + + output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + + # Quantize Scale calculation + # Quantize + output = tl.fdiv(output, scale_output) + output = tl.reshape(output, (BLOCK_M, BLOCK_N)) + + if STOCHASTIC: + # noise_block_ptr = tl.make_block_ptr( + # base=noise_ptr, + # shape=(M, N), + # strides=(input_stride_0, input_stride_1), + # offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + # block_shape=(BLOCK_M, BLOCK_N), + # order=(1, 0) + # ) + # noise = tl.load(noise_block_ptr) + + offs_m = pid_dim0 * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_dim1 * BLOCK_N + tl.arange(0, BLOCK_N) + noise_offset = offs_m[:, None] * input_stride_0 + offs_n[None, :] * input_stride_1 + noise = tl.rand(0, noise_offset) + + output = _stochastic_rounding(output, noise, e_bit, m_bit) + + output = output.to(output_ptr.type.element_ty) + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, 
output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + tl.store(output_block_ptr, output, boundary_check=(0, 1)) + + +@triton.jit +def _stochastic_rounding(output, noise, e_bit: tl.constexpr, m_bit: tl.constexpr): + subnormal_min = tl.exp2(2 - tl.exp2(e_bit - 1) - m_bit) + # subnormal_should_be = tl.exp2(2 - tl.exp2(e_bit) - 1) + + output_int32 = tl.cast(output, tl.int32, bitcast=True) + output_int32 = output_int32 & 0x7F800000 + output_float32 = tl.cast(output_int32, tl.float32, bitcast=True) + output_exp = tl.maximum(output_float32, subnormal_min) + + noise_rescale = tl.exp2(m_bit) + (output_exp == subnormal_min) * ( + 1 - tl.exp2(m_bit) + ) # 2^m_bit for normal, 1 for subnormal + + noise = output_exp * noise / noise_rescale + sign = 1 - 2 * libdevice.signbit(output) + output = tl.abs(output) + noise + + minmax_ratio = 2 + (output_exp == subnormal_min) * (tl.exp2(m_bit) - 2) # 2 for normal, and 2^M for subnormal + output = sign * tl.clamp(output, min=output_exp, max=minmax_ratio * output_exp) + + return output + + +def fp8_division(x, QB, fp8type, s_y=None, stochastic=False): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + + if stochastic: + # noise = torch.zeros_like(x, dtype=torch.float32).uniform_(-0.5, 0.5) + noise = None + else: + noise = None + + # defining the input and output tensor + M, N = x.shape + SN = N // QB + + if isinstance(fp8type, str): + fp8type = convert_str_to_fp8[fp8type] + + y = torch.empty_like(x, dtype=fp8type) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + e_bit, m_bit = convert_fp8_to_embit[fp8type] + + if s_y is None: + s_y = (x.abs().max() + SCALE_MIN_THRES) / fp8MaxValue + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_division_kernel[grid]( + y, + x, + s_y, + noise, + M, + N, + SN, + QB, + fp8MaxValue, + e_bit, + m_bit, + x.stride(0), + x.stride(1), + y.stride(0), + y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + STOCHASTIC=stochastic, + ) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + + return y, s_y # y_t is expected to be 2D tensor diff --git a/llava/model/coat/activation/real_quantization/_division_transpose.py b/llava/model/coat/activation/real_quantization/_division_transpose.py new file mode 100644 index 0000000000000000000000000000000000000000..7fbabaa31e4f47482c98d69431a9dfae488a2f15 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/_division_transpose.py @@ -0,0 +1,215 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
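
`fp8_division` above casts a BF16/FP32 tensor to FP8 with a single per-tensor scale, deriving the scale from `abs().max()` when none is supplied and optionally applying the stochastic rounding kernel. A hedged usage sketch (illustrative shapes; CUDA GPU with Triton required):

```python
# Usage sketch for fp8_division: per-tensor FP8 cast with optional stochastic rounding.
import torch

from llava.model.coat.activation.real_quantization import fp8_division

QB = 16
x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")

# Let the kernel derive the per-tensor scale from x.abs().max().
y, s_y = fp8_division(x, QB, "E4M3")                  # y: float8_e4m3fn, s_y: scalar scale

# E5M2 with stochastic rounding, as one might use for gradient-like tensors.
g = torch.randn_like(x)
qg, s_g = fp8_division(g, QB, "E5M2", stochastic=True)
```
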
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division import _stochastic_rounding +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block + +"""Division and Transpose Operator""" +"""Input uses full-precision/BF16""" +"""Output uses per tensor quantization""" +"""Output_t uses per tensor quantization and is transposed, but is flattened to 2D""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), # triton.Config({'BLOCK_M': 1, 'BLOCK_N': 16}, num_stages=4, num_warps=1,) + # configs=[triton.Config({'BLOCK_M': 1, 'BLOCK_N': 16}, num_stages=4, num_warps=1,)], # + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_division_transpose_kernel( + output_ptr, + output_t_ptr, # output + input_ptr, + input_scale_ptr, # input + noise_ptr, # noise for stochastic + M, + N, + SN, + QB: tl.constexpr, + fp8_max, + e_bit, + m_bit, # shape + input_stride_0, + input_stride_1, # input stride + output_stride_0, + output_stride_1, # output stride + output_t_stride_0, + output_t_stride_1, # output stride + SCALE_MIN_THRES: tl.constexpr, # We do not use it since we believe SCALE_MIN_THRES should be used in previous kernel when calculating scaling factor + STOCHASTIC: tl.constexpr, + ONLY_TRANSPOSED: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + input = input.to(tl.float32) + scale_output = tl.load(input_scale_ptr) + scale_output = scale_output.to(tl.float32) + + output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + + # Quantize Scale calculation + # Quantize + output = tl.fdiv(output, scale_output) + output = tl.reshape(output, (BLOCK_M, BLOCK_N)) + + if STOCHASTIC: + # noise_block_ptr = tl.make_block_ptr( + # base=noise_ptr, + # shape=(M, N), + # strides=(input_stride_0, input_stride_1), + # offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + # block_shape=(BLOCK_M, BLOCK_N), + # order=(1, 0) + # ) + # noise = tl.load(noise_block_ptr) + + offs_m = pid_dim0 * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_dim1 * BLOCK_N + tl.arange(0, BLOCK_N) + noise_offset = offs_m[:, None] * input_stride_0 + offs_n[None, :] * input_stride_1 + noise = tl.rand(0, noise_offset) + + output = _stochastic_rounding(output, noise, e_bit, m_bit) + + output = output.to(output_ptr.type.element_ty) + # tl.device_print("3: ", output) + output_t = tl.trans(output) + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + output_t_block_ptr = tl.make_block_ptr( + base=output_t_ptr, + shape=(N, M), + strides=(output_t_stride_0, output_t_stride_1), + offsets=(pid_dim1 * BLOCK_N, pid_dim0 * BLOCK_M), + block_shape=(BLOCK_N, BLOCK_M), + 
order=(1, 0), + ) + if not ONLY_TRANSPOSED: + tl.store(output_block_ptr, output, boundary_check=(0, 1)) + tl.store(output_t_block_ptr, output_t, boundary_check=(0, 1)) + + +def fp8_division_transpose(x, QB, fp8type, s_y=None, stochastic=False, only_transposed=False): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + + if stochastic: + # noise = torch.empty_like(x, dtype=torch.float32).uniform_(-0.5, 0.5) + noise = None + else: + noise = None + + # defining the input and output tensor + M, N = x.shape + SN = N // QB + + if isinstance(fp8type, str): + fp8type = convert_str_to_fp8[fp8type] + + y = torch.empty_like(x, dtype=fp8type) + y_t = torch.empty((N, M), dtype=fp8type, device=x.device) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + e_bit, m_bit = convert_fp8_to_embit[fp8type] + + if s_y is None: + # print("Warning: do not specify s_y in fp8_division_transpose") + s_y = (x.abs().max() + SCALE_MIN_THRES) / fp8MaxValue + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_division_transpose_kernel[grid]( + y, + y_t, + x, + s_y, + noise, + M, + N, + SN, + QB, + fp8MaxValue, + e_bit, + m_bit, + x.stride(0), + x.stride(1), + y.stride(0), + y.stride(1), + y_t.stride(0), + y_t.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + STOCHASTIC=stochastic, + ONLY_TRANSPOSED=only_transposed, + ) + + if not only_transposed: + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + + return y, s_y, y_t # y_t is expected to be 2D tensor + else: + return y_t, s_y diff --git a/llava/model/coat/activation/real_quantization/_memory_io.py b/llava/model/coat/activation/real_quantization/_memory_io.py new file mode 100644 index 0000000000000000000000000000000000000000..b04d2a7afcf749b634cbd5d88c6d99b42030d5ff --- /dev/null +++ b/llava/model/coat/activation/real_quantization/_memory_io.py @@ -0,0 +1,180 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
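
`fp8_division_transpose` above quantizes with a per-tensor scale and emits both the activation and a transposed 2D copy in one launch (a layout commonly needed by backward matmuls). Sketch with illustrative shapes; a CUDA GPU with Triton is assumed:

```python
# Sketch for fp8_division_transpose: one launch returns the quantized tensor
# and its transpose (the transpose is always flattened to 2D).
import torch

from llava.model.coat.activation.real_quantization import fp8_division_transpose

QB = 16
x = torch.randn(4, 128, 256, dtype=torch.bfloat16, device="cuda")  # batched 3D input

y, s_y, y_t = fp8_division_transpose(x, QB, "E4M3")
print(y.shape)    # torch.Size([4, 128, 256]) -- same layout as x, float8_e4m3fn
print(y_t.shape)  # torch.Size([256, 512])    -- (N, B * M), flattened to 2D

# only_transposed=True skips the non-transposed copy and returns (y_t, s_y).
y_t_only, s_y2 = fp8_division_transpose(x, QB, "E4M3", only_transposed=True)
```
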
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +CONST_BLOCK = 32 + +# The kernel with 1 load operation and 4 store operation +def get_configs_io_block(): + configs = [] + for nstages in [3, 4, 5, 6]: + for block_m in [32, 64, 128]: + for block_n in [32, 64, 128]: + for nwarps in [4, 8, 16, 32]: + configs.append( + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n}, + num_stages=nstages, + num_warps=nwarps, + ) + ) + return configs + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.jit +def bench_memory_io_kernel_forward( + output_ptr, + input_ptr, + M, + N, + B: tl.constexpr, + input_stride_0, + input_stride_1, + output_stride_0, + output_stride_1, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_M = tl.cdiv(M, BLOCK_M) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + input = input.to(tl.float32) + + output = input * 2 + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + output = output.to(output_ptr.type.element_ty) + tl.store(output_block_ptr, output, boundary_check=(0, 1)) + + +def bench_memory_io_forward(x, B): + # defining the input and output tensor + M, N = x.shape + + y = torch.empty_like(x, dtype=x.dtype) + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + bench_memory_io_kernel_forward[grid]( + y, + x, + M, + N, + B, + x.stride(0), + x.stride(1), + y.stride(0), + y.stride(1), + ) + return y + + +configs = [] +for SL in [8192]: + configs.append( + triton.testing.Benchmark( # test different matrix size influence + x_names=["CDIM"], + x_vals=[1024, 2048, 4096, 8192], + line_arg="dtype", + line_vals=[torch.int8, torch.float16, torch.float32], + line_names=["float8", "float16", "float32"], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + ylabel="time-cost", + plot_name=f"INT8GELU", + args={"SL": SL, "B": CONST_BLOCK, "provider": "triton", "mode": "time-consuming"}, + ) + ) + + +@triton.testing.perf_report(configs) +def bench_load_store( + SL, CDIM, B, provider, dtype, mode="forward" +): # I only use triton as the provider, and mode when benchmarking + # create data + x = torch.randn(SL, CDIM, dtype=torch.float32).cuda() + x = x.to(dtype) + + quantiles = [0.5, 0.2, 0.8] + # utility functions + if provider == "triton": + + def y_fwd(): + bench_memory_io_forward(x, B) + + if provider == "torch": + torch_gelu = torch.nn.GELU() + + def y_fwd(): + return torch_gelu(x) + + # forward pass + if mode == "time-consuming": + convert_func = lambda ms: ms + ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=100) + # backward pass + if mode == "gbps": + convert_func = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6 + ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=100) + return convert_func(ms), convert_func(max_ms), convert_func(min_ms) + + +if 
__name__ == "__main__": + torch.manual_seed(0) + torch.set_printoptions(precision=8, linewidth=1600, sci_mode=False, edgeitems=3) + bench_load_store.run(print_data=True) diff --git a/llava/model/coat/activation/real_quantization/_quantize.py b/llava/model/coat/activation/real_quantization/_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..e1385b7dc42f25f0505a42cfecad48d9fa6eaa51 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/_quantize.py @@ -0,0 +1,176 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block + +"""Quantize Operator""" +"""Input uses 1 * 16 group quantization""" +"""Output uses 1 * 16 group quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_quantize_kernel( + output_ptr, + output_scale_ptr, # output + input_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input_stride_0, + input_stride_1, # input stride + output_stride_0, + output_stride_1, # output stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + input = input.to(tl.float32) + + output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + + # Quantize Scale calculation + abs_output = tl.abs(output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize + output = tl.fdiv(output, scale_output) + + output = output.to(output_ptr.type.element_ty) + + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + output = tl.reshape(output, (BLOCK_M, BLOCK_N)) + + # debug + # gelu_output = input + # scale_output = scale_input + + # pointers + output_block_ptr = 
tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(output_block_ptr, output, boundary_check=(0, 1)) + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_quantize(x, QB, fp8type): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + SN = N // QB + + if isinstance(fp8type, str): + fp8type = convert_str_to_fp8[fp8type] + y = torch.empty_like(x, dtype=fp8type) + s_y = torch.empty((M, SN), dtype=torch.bfloat16, device=x.device) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_quantize_kernel[grid]( + y, + s_y, + x, + M, + N, + SN, + QB, + fp8MaxValue, + x.stride(0), + x.stride(1), + y.stride(0), + y.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + + return y, s_y diff --git a/llava/model/coat/activation/real_quantization/_quantize_pertensor.py b/llava/model/coat/activation/real_quantization/_quantize_pertensor.py new file mode 100644 index 0000000000000000000000000000000000000000..46cadbc6e77e2e34c1fafec968aa17c5621da034 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/_quantize_pertensor.py @@ -0,0 +1,152 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
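
To make the per-group scale layout of `fp8_quantize` concrete, the sketch below checks the kernel's scales against a plain PyTorch reference built from the same constants (`FP8_MAX_VALUE`, `SCALE_MIN_THRES`) defined in `common.py`. Shapes and tolerances are illustrative, and a CUDA GPU with Triton is assumed.

```python
# Sketch: fp8_quantize produces one scale per 1x16 group, laid out as (M, N // QB).
import torch

from llava.model.coat.activation.real_quantization import fp8_quantize
from llava.model.coat.activation.real_quantization.common import FP8_MAX_VALUE, SCALE_MIN_THRES

QB = 16
x = torch.randn(64, 256, dtype=torch.bfloat16, device="cuda")

qx, sx = fp8_quantize(x, QB, "E4M3")
print(sx.shape)  # torch.Size([64, 16]) -- one scale per 1x16 group

# Reference: per-group absmax / fp8_max, mirroring the kernel's scale computation.
ref = (x.float().reshape(64, -1, QB).abs().amax(dim=-1) + SCALE_MIN_THRES) / FP8_MAX_VALUE[torch.float8_e4m3fn]
print(torch.allclose(sx.float(), ref, rtol=1e-2))  # loose tolerance: sx is stored in bfloat16
```
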
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division import fp8_division +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block + +"""Per Tensor Quantize Operator""" +"""Input uses full precision""" +"""Output uses per tensor quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_quantize_pertensor_kernel( + output_scale_ptr, # output + input_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input_stride_0, + input_stride_1, # input stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + input = input.to(tl.float32) + + output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + + # Quantize Scale calculation + abs_output = tl.abs(output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_quantize_pertensor(x, QB, fp8type, stochastic=False): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + SN = N // QB + + fp8type = convert_str_to_fp8[fp8type] + s_y = torch.empty((M, SN), dtype=torch.bfloat16, device=x.device) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_quantize_pertensor_kernel[grid]( + s_y, + x, + M, + N, + SN, + QB, + fp8MaxValue, + x.stride(0), + x.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + s_y_max = s_y.max() + y, s_y_max = fp8_division(x, QB, fp8type, s_y_max, stochastic=stochastic) # reuse the floating point output y1 + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + + return y, s_y_max, s_y diff --git a/llava/model/coat/activation/real_quantization/_quantize_pertensor_transpose.py b/llava/model/coat/activation/real_quantization/_quantize_pertensor_transpose.py new file mode 100644 index 
0000000000000000000000000000000000000000..4e91d01445c99928aa7f79227cdbc7706fee1981 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/_quantize_pertensor_transpose.py @@ -0,0 +1,155 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division_transpose import fp8_division_transpose +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block + +"""Per Tensor Quantize and Transpose Operator""" +"""Input uses floating point tensor""" +"""Output uses per-tensor quantization, returns a non-transpose version and a transpose version""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_quantize_pertensor_transpose_kernel( + output_scale_ptr, # output + input_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input_stride_0, + input_stride_1, # input stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + input = input.to(tl.float32) + + output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + + # Quantize Scale calculation + abs_output = tl.abs(output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_quantize_pertensor_transpose(x, QB, fp8type, transpose_output_2d=False, stochastic=False): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + batched = True + BS 
= x.shape[0] + x = x.reshape(-1, x.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + SN = N // QB + + fp8type = convert_str_to_fp8[fp8type] + s_y = torch.empty((M, SN), dtype=torch.bfloat16, device=x.device) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_quantize_pertensor_transpose_kernel[grid]( + s_y, + x, + M, + N, + SN, + QB, + fp8MaxValue, + x.stride(0), + x.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + s_y_max = s_y.max() + qy, s_y_max, qy_t = fp8_division_transpose( + x, QB, fp8type, s_y_max, stochastic=stochastic + ) # Stochastic Rounding happens here + + # Recover 2D to 3D + if batched: + qy = qy.reshape(BS, -1, qy.shape[-1]) + if not transpose_output_2d: + qy_t = qy_t.reshape(BS, -1, qy_t.shape[-1]) + + return qy, s_y_max, qy_t # y_t is expected to be 2D tensor diff --git a/llava/model/coat/activation/real_quantization/_transpose.py b/llava/model/coat/activation/real_quantization/_transpose.py new file mode 100644 index 0000000000000000000000000000000000000000..5b0b8ec2dc499d2bcd661e619d9f54e91daeac0c --- /dev/null +++ b/llava/model/coat/activation/real_quantization/_transpose.py @@ -0,0 +1,121 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
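
The two per-tensor wrappers above first compute group scales, reduce them to one tensor-wide scale, and then quantize (the transpose variant also emitting a transposed 2D copy). A usage sketch with illustrative shapes; a CUDA GPU with Triton is assumed:

```python
# Sketch for the per-tensor quantize wrappers (with and without transpose).
import torch

from llava.model.coat.activation.real_quantization import (
    fp8_quantize_pertensor,
    fp8_quantize_pertensor_transpose,
)

QB = 16
x = torch.randn(2, 64, 256, dtype=torch.bfloat16, device="cuda")

# Per-tensor scale plus the per-group scales it was reduced from.
qx, s_max, s_groups = fp8_quantize_pertensor(x, QB, "E4M3")
print(s_max.shape, s_groups.shape)  # torch.Size([]) torch.Size([2, 64, 16])

# Same, but also returning the transposed copy kept as a 2D tensor.
qy, s_y, qy_t = fp8_quantize_pertensor_transpose(x, QB, "E4M3", transpose_output_2d=True)
print(qy.shape, qy_t.shape)         # torch.Size([2, 64, 256]) torch.Size([256, 128])
```
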
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from .common import get_configs_io_block + +"""Quantize Operator""" +"""Input uses 1 * 16 group quantization""" +"""Output uses 1 * 16 group quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.jit +def _fp8_transpose_kernel( + output_ptr, # output + input_ptr, # input + M, + N, # shape + input_stride_0, + input_stride_1, # input stride + output_stride_0, + output_stride_1, # output stride + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + + output = tl.trans(input) + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(N, M), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim1 * BLOCK_N, pid_dim0 * BLOCK_M), + block_shape=(BLOCK_N, BLOCK_M), + order=(1, 0), + ) + + tl.store(output_block_ptr, output, boundary_check=(0, 1)) + + +def fp8_transpose(x, transpose_output_2d=False): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + + y = torch.empty((N, M), dtype=x.dtype, device=x.device) + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_transpose_kernel[grid]( + y, + x, + M, + N, + x.stride(0), + x.stride(1), + y.stride(0), + y.stride(1), + ) + + # Recover 2D to 3D + if batched and not transpose_output_2d: + y = y.reshape(BS, -1, y.shape[-1]) + + return y diff --git a/llava/model/coat/activation/real_quantization/add_bwd.py b/llava/model/coat/activation/real_quantization/add_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..5951afec125048d8549fd1be0ae204df1a6b222a --- /dev/null +++ b/llava/model/coat/activation/real_quantization/add_bwd.py @@ -0,0 +1,205 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
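
`fp8_transpose` above is a standalone Triton kernel that returns a contiguous transposed copy while preserving the input dtype (including FP8). Sketch with illustrative shapes; a CUDA GPU is assumed:

```python
# Sketch for fp8_transpose: contiguous transposed copy that keeps the FP8 dtype.
import torch

from llava.model.coat.activation.real_quantization import fp8_quantize, fp8_transpose

QB = 16
x = torch.randn(4, 64, 128, dtype=torch.bfloat16, device="cuda")
qx, _ = fp8_quantize(x, QB, "E4M3")            # float8_e4m3fn, shape (4, 64, 128)

qx_t = fp8_transpose(qx, transpose_output_2d=True)
print(qx_t.shape, qx_t.dtype)                  # torch.Size([128, 256]) torch.float8_e4m3fn
```
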
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division import fp8_division +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block + +"""Element-wise Add, useful for backward""" +"""Input1 (Residual) uses full-precision/BF16""" +"""Input2 (Backbone) uses full-precision/BF16""" +"""Output1 uses full-precision/BF16""" +"""Output2 uses per-tensor quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_add_Ifp_Ifp_Ofp_Opt_kernel( + output1_ptr, # output + output2_scale_ptr, + input1_ptr, # input + input2_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input1_stride_0, + input1_stride_1, # input1 stride + input2_stride_0, + input2_stride_1, # input2 stride + output1_stride_0, + output1_stride_1, # output stride + s_output2_stride_0, + s_output2_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # --- The first input --- + input1_block_ptr = tl.make_block_ptr( + base=input1_ptr, + shape=(M, N), + strides=(input1_stride_0, input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input1 = tl.load(input1_block_ptr) + input1 = input1.to(tl.float32) + input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB)) + + # --- The second input --- + input2_block_ptr = tl.make_block_ptr( + base=input2_ptr, + shape=(M, N), + strides=(input2_stride_0, input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input2 = tl.load(input2_block_ptr) + input2 = input2.to(tl.float32) + input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB)) + + # Actual Calculation of Add + add_output = input1 + input2 + + # Quantize the grad 1 - Scale calculation + abs_add_output = tl.abs(add_output) + max_val = tl.max(abs_add_output, axis=2) + SCALE_MIN_THRES + scale_output2 = max_val / fp8_max + scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN, 1)) + + # save the fp add output + fp_add_output = add_output.to(output1_ptr.type.element_ty) + fp_add_output = tl.reshape(fp_add_output, (BLOCK_M, BLOCK_N)) + + # pointers + output1_block_ptr = tl.make_block_ptr( + base=output1_ptr, + shape=(M, N), + strides=(output1_stride_0, output1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + tl.store(output1_block_ptr, fp_add_output, boundary_check=(0, 1)) + + # Quantize + scale_output2 = scale_output2.to(output2_scale_ptr.type.element_ty) + scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN)) + + # pointers + scale_output2_ptr = tl.make_block_ptr( + base=output2_scale_ptr, + shape=(M, SN), + strides=(s_output2_stride_0, s_output2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + tl.store(scale_output2_ptr, scale_output2, boundary_check=(0, 1)) + + +def fp8_add_Ifp_Ifp_Ofp_Opt(x1, x2, QB, 
fp8type, stochastic=False): # suppose x1 is full precision or BF16 + # Change batched 3D input to 2D + batched = False + if len(x1.shape) == 3: + assert len(x2.shape) == 3 + batched = True + BS = x1.shape[0] + x1 = x1.reshape(-1, x1.shape[-1]) + x2 = x2.reshape(-1, x2.shape[-1]) + + # defining the input and output tensor + M, N = x1.shape + SN = N // QB + assert x1.shape == x2.shape + + if isinstance(fp8type, str): + fp8type = convert_str_to_fp8[fp8type] + y1 = torch.empty_like(x1, dtype=torch.bfloat16) + s_y2 = torch.empty((M, SN), dtype=torch.bfloat16, device=x2.device) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_add_Ifp_Ifp_Ofp_Opt_kernel[grid]( + y1, + s_y2, + x1, + x2, + M, + N, + SN, + QB, + fp8MaxValue, + x1.stride(0), + x1.stride(1), + x2.stride(0), + x2.stride(1), + y1.stride(0), + y1.stride(1), + s_y2.stride(0), + s_y2.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + s_y2_max = s_y2.max() + qy2, s_y2_max = fp8_division(y1, QB, fp8type, s_y2_max, stochastic=stochastic) # reuse the floating point output y1 + + # Recover 2D to 3D + if batched: + y1 = y1.reshape(BS, -1, y1.shape[-1]) + qy2 = qy2.reshape(BS, -1, qy2.shape[-1]) + s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1]) + + return y1, (qy2, s_y2_max, s_y2) diff --git a/llava/model/coat/activation/real_quantization/add_fwd.py b/llava/model/coat/activation/real_quantization/add_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..90aed0b3573e4c1a21b4eecea959f6f537f61be8 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/add_fwd.py @@ -0,0 +1,219 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
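
The backward-pass add above returns the BF16 sum for the residual stream together with a per-tensor-quantized copy and its scales for the next FP8 matmul. A usage sketch (illustrative shapes; names like `grad_residual` are placeholders; CUDA GPU with Triton assumed):

```python
# Sketch for fp8_add_Ifp_Ifp_Ofp_Opt (backward-pass residual add).
import torch

from llava.model.coat.activation.real_quantization import fp8_add_Ifp_Ifp_Ofp_Opt

QB = 16
grad_residual = torch.randn(2, 64, 128, dtype=torch.bfloat16, device="cuda")
grad_backbone = torch.randn_like(grad_residual)

y, (qy, s_max, s_groups) = fp8_add_Ifp_Ifp_Ofp_Opt(
    grad_residual, grad_backbone, QB, "E5M2", stochastic=True
)
print(y.dtype, qy.dtype)             # torch.bfloat16 torch.float8_e5m2
print(s_max.shape, s_groups.shape)   # torch.Size([]) torch.Size([2, 64, 8])
```
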
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block + +"""Element-wise Add, used in forward pass""" +"""Input1 (Residual) uses full-precision/BF16""" +"""Input2 (Backbone) uses full-precision/BF16""" +"""Output1 uses full-precision/BF16""" +"""Output2 uses 1 * 16 group quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_add_Ifp_Ifp_Ofp_Og16_kernel( + output1_ptr, # output + output2_ptr, + output2_scale_ptr, + input1_ptr, # input + input2_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input1_stride_0, + input1_stride_1, # input1 stride + input2_stride_0, + input2_stride_1, # input2 stride + output1_stride_0, + output1_stride_1, # output stride + output2_stride_0, + output2_stride_1, # output stride + s_output2_stride_0, + s_output2_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # --- The first input --- + input1_block_ptr = tl.make_block_ptr( + base=input1_ptr, + shape=(M, N), + strides=(input1_stride_0, input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input1 = tl.load(input1_block_ptr) + input1 = input1.to(tl.float32) + input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB)) + + # --- The second input --- + input2_block_ptr = tl.make_block_ptr( + base=input2_ptr, + shape=(M, N), + strides=(input2_stride_0, input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + input2 = tl.load(input2_block_ptr) + input2 = input2.to(tl.float32) + input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB)) + + # Actual Calculation of Add + add_output = input1 + input2 + + # Quantize the grad 1 - Scale calculation + abs_add_output = tl.abs(add_output) + max_val = tl.max(abs_add_output, axis=2) + SCALE_MIN_THRES + scale_output2 = max_val / fp8_max + scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN, 1)) + + # save the fp add output + fp_add_output = add_output.to(output1_ptr.type.element_ty) + fp_add_output = tl.reshape(fp_add_output, (BLOCK_M, BLOCK_N)) + + # pointers + output1_block_ptr = tl.make_block_ptr( + base=output1_ptr, + shape=(M, N), + strides=(output1_stride_0, output1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + tl.store(output1_block_ptr, fp_add_output) + + # Quantize + add_output = tl.fdiv(add_output, scale_output2) + scale_output2 = scale_output2.to(output2_scale_ptr.type.element_ty) + scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN)) + add_output = tl.reshape(add_output, (BLOCK_M, BLOCK_N)) + + add_output = add_output.to(output2_ptr.type.element_ty) + add_output = tl.reshape(add_output, (BLOCK_M, BLOCK_N)) + + # pointers + output2_block_ptr = tl.make_block_ptr( + base=output2_ptr, + shape=(M, N), + strides=(output2_stride_0, output2_stride_1), + 
offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output2_ptr = tl.make_block_ptr( + base=output2_scale_ptr, + shape=(M, SN), + strides=(s_output2_stride_0, s_output2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + tl.store(output2_block_ptr, add_output, boundary_check=(0, 1)) + tl.store(scale_output2_ptr, scale_output2, boundary_check=(0, 1)) + + +def fp8_add_Ifp_Ifp_Ofp_Og16(x1, x2, fp8type, QB): # suppose x1 is full precision or BF16 + # Change batched 3D input to 2D + batched = False + if len(x1.shape) == 3: + batched = True + BS = x1.shape[0] + x1 = x1.reshape(-1, x1.shape[-1]) + x2 = x2.reshape(-1, x2.shape[-1]) + + # defining the input and output tensor + M, N = x1.shape + SN = int(N / QB) # assume the shape of quantization block size is always 1 * G + assert x1.shape == x2.shape + + y1 = torch.empty_like(x1, dtype=torch.bfloat16) + y2 = torch.empty_like(x2, dtype=fp8type) + s_y2 = torch.empty((M, SN), dtype=torch.bfloat16, device=x2.device) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_add_Ifp_Ifp_Ofp_Og16_kernel[grid]( + y1, + y2, + s_y2, + x1, + x2, + M, + N, + SN, + QB, + fp8MaxValue, + x1.stride(0), + x1.stride(1), + x2.stride(0), + x2.stride(1), + y1.stride(0), + y1.stride(1), + y2.stride(0), + y2.stride(1), + s_y2.stride(0), + s_y2.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + # Recover 2D to 3D + if batched: + y1 = y1.reshape(BS, -1, y1.shape[-1]) + y2 = y2.reshape(BS, -1, y2.shape[-1]) + s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1]) + + return y1, (y2, s_y2) diff --git a/llava/model/coat/activation/real_quantization/common.py b/llava/model/coat/activation/real_quantization/common.py new file mode 100644 index 0000000000000000000000000000000000000000..00de715cf3e80eddf74ea401e1bbc9811c3e4970 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/common.py @@ -0,0 +1,59 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
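+# The constants defined below are shared by every quantization kernel in this
+# package: FP8_MAX_VALUE records the largest representable magnitude of each
+# FP8 format (448 for E4M3, 57344 for E5M2), SCALE_MIN_THRES keeps per-group
+# scales away from zero, and get_configs_io_block() enumerates the
+# (BLOCK_M, BLOCK_N, num_stages, num_warps) search space handed to
+# @triton.autotune by the element-wise kernels, keyed on the row width N.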
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton + +SCALE_MIN_THRES = 1e-10 + +FP8_MAX_VALUE = { + torch.float8_e4m3fn: 448.0, + torch.float8_e5m2: 57344.0, +} + +convert_str_to_fp8 = {"E4M3": torch.float8_e4m3fn, "E5M2": torch.float8_e5m2} +convert_fp8_to_embit = { + torch.float8_e4m3fn: (4.0, 3.0), + torch.float8_e5m2: (5.0, 2.0), +} + + +def get_configs_io_block(): + configs = [] + for nstages in [3, 4, 5]: + for block_m in [32, 64]: + for block_n in [64, 128]: + for nwarps in [4, 8]: + configs.append( + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n}, + num_stages=nstages, + num_warps=nwarps, + ) + ) + return configs + + +# from .common import SCALE_MIN_THRES, FP8_MAX_VALUE +# SCALE_MIN_THRES: tl.constexpr, +# + SCALE_MIN_THRES +# SCALE_MIN_THRES=SCALE_MIN_THRES, diff --git a/llava/model/coat/activation/real_quantization/fp8linear.py b/llava/model/coat/activation/real_quantization/fp8linear.py new file mode 100644 index 0000000000000000000000000000000000000000..32531fcb4bb63e4720e3fc48fb08b1f2d82d54db --- /dev/null +++ b/llava/model/coat/activation/real_quantization/fp8linear.py @@ -0,0 +1,250 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
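+# FP8Linear below swaps nn.Linear's matmul for an FP8 GEMM during training.
+# Conceptually the forward pass computes (a rough sketch; the names q_x, s_x,
+# q_w, s_w are illustrative, the real work is done by
+# fp8_quantize_pertensor_transpose and fp8_linear_forward):
+#
+#     q_x, s_x = per_tensor_fp8(x)            # x ~= q_x * s_x
+#     q_w, s_w = per_tensor_fp8(weight)       # weight ~= q_w * s_w
+#     y = (q_x @ q_w.T) * (s_x * s_w) + bias  # FP8 GEMM, rescaled to BF16
+#
+# The transposed FP8 copies and the two scales are stashed for the backward
+# pass, which reuses them to form grad_input and grad_weight in FP8 as well.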
+# +# SPDX-License-Identifier: Apache-2.0 + +import os +import time +from copy import deepcopy +from dataclasses import dataclass + +import matplotlib.pyplot as plt +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd.function import Function + +from ..utils import quant_get_local_rank +from ._division_transpose import fp8_division_transpose +from ._quantize_pertensor_transpose import fp8_quantize_pertensor_transpose +from .linear import fp8_linear_backward, fp8_linear_forward + + +@dataclass +class DefaultArgs: + fabit: int + fwbit: int + bobit: int + + +class FP8Linear(nn.Linear): + def __init__(self, in_features, out_features, bias=True, device=None, args=None, layer_idx=0): + super().__init__(in_features, out_features, bias, device) + + if args is None: # I do not want to pass a new argument to OLMo so just use this method + args = DefaultArgs( + fabit=os.environ["FABIT_FP8Linear"], + fwbit=os.environ["FWBIT_FP8Linear"], + bobit=os.environ["BOBIT_FP8Linear"], + ) + self.args = deepcopy(args) + + if quant_get_local_rank() == 0: + print(f"[qlinear debug] Apply QLinear, {layer_idx}") + + self.layer_idx = layer_idx + self.layer_name = None + + def forward(self, Input): + if self.training: + # if False: + output = QuantLinearTE.apply(Input, self.weight, self.bias, self.args, self.layer_name) + else: + output = F.linear(Input, self.weight, self.bias) + + return output + + +# if int(os.environ.get("LOCAL_RANK")) == 0: +# import IPython +# IPython.embed() +# else: +# import time +# time.sleep(1000) + +# class QuantLinearTE(Function): +# @staticmethod +# def forward(ctx, input, weight, bias, args, layer_type): +# ctx.saved = input, weight, bias, args, layer_type +# return F.linear(input, weight, bias) + +# @staticmethod +# def backward(ctx, grad_output): +# input, weight, bias, args, layer_type = ctx.saved + +# C_in = input.shape[-1] +# C_out = grad_output.shape[-1] + +# grad_output_flatten = grad_output.reshape(-1, C_out) +# input_flatten = input.reshape(-1, C_in) + +# if grad_output_flatten.dtype == input_flatten.dtype: +# grad_weight = grad_output_flatten.t().mm(input_flatten) +# else: +# grad_weight = grad_output_flatten.float().t().mm(input_flatten) + +# if grad_output_flatten.dtype == weight.dtype: +# grad_input = grad_output_flatten.mm(weight) +# else: +# grad_input = grad_output_flatten.float().mm(weight) + +# if bias is not None: +# grad_bias = grad_output_flatten.sum(0) +# else: +# grad_bias = None + +# grad_input_transform = grad_input.reshape(input.size()) + +# return grad_input_transform, grad_weight, grad_bias, None, None + + +class QuantLinearTE(Function): + @staticmethod + @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.bfloat16) + def forward(ctx, input, weight, bias, args, layer_name): + + time_bench = os.getenv("TIME_BENCH") + + if time_bench: + start_1 = torch.cuda.Event(enable_timing=True) + start_1.record() + + # Qinput, Iscale, Qinput_t = fp8_division_transpose(input, 16, args.fabit) + Qinput, Iscale, Qinput_t = fp8_quantize_pertensor_transpose(input, 16, args.fabit, transpose_output_2d=True) + + if time_bench: + end_1 = torch.cuda.Event(enable_timing=True) + end_1.record() + start_2 = torch.cuda.Event(enable_timing=True) + start_2.record() + + # Qweight, Wscale, Qweight_t = fp8_division_transpose(weight, 16, args.fwbit) + Qweight, Wscale, Qweight_t = fp8_quantize_pertensor_transpose(weight, 16, args.fwbit, transpose_output_2d=True) + + if time_bench: + end_2 = torch.cuda.Event(enable_timing=True) + end_2.record() + 
start_3 = torch.cuda.Event(enable_timing=True) + start_3.record() + + ctx.saved = Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name + fc_output = fp8_linear_forward(Qinput, Iscale, Qweight, Wscale, False, 0, bias) + + if time_bench: + end_3 = torch.cuda.Event(enable_timing=True) + end_3.record() + start_4 = torch.cuda.Event(enable_timing=True) + start_4.record() + + output = F.linear(input, weight, bias) + + end_4 = torch.cuda.Event(enable_timing=True) + end_4.record() + + torch.cuda.synchronize() + if quant_get_local_rank() == 0: + print( + f"[Forward] Part 1: {start_1.elapsed_time(end_1):.6f} ms | Part 2: {start_2.elapsed_time(end_2):.6f} ms | Part 3: {start_3.elapsed_time(end_3):.6f} ms | " + f"FP8: {start_1.elapsed_time(end_3):.6f} | BF16: {start_4.elapsed_time(end_4):.6f} | Input shape: {input.shape} | Weight shape: {weight.shape}" + ) + + return fc_output + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, grad_output): + Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name = ctx.saved + + time_bench = os.getenv("TIME_BENCH") + if time_bench: + start_1 = torch.cuda.Event(enable_timing=True) + start_1.record() + + # Qgrad_output, Gscale, Qgrad_output_t = fp8_division_transpose(grad_output, 16, args.bobit, stochastic=False) + Qgrad_output, Gscale, Qgrad_output_t = fp8_quantize_pertensor_transpose( + grad_output, 16, args.bobit, stochastic=False, transpose_output_2d=True + ) + + if time_bench: + end_1 = torch.cuda.Event(enable_timing=True) + end_1.record() + start_2 = torch.cuda.Event(enable_timing=True) + start_2.record() + + grad_input, grad_weight = fp8_linear_backward( + Qinput_t, + Iscale, + Qgrad_output, + Gscale, + Qgrad_output_t, + Qweight_t, + Wscale, + 16, + bias, + stochastic=False, + dgrad_quantize=False, + ) + + if time_bench: + end_2 = torch.cuda.Event(enable_timing=True) + end_2.record() + start_3 = torch.cuda.Event(enable_timing=True) + start_3.record() + + if bias is not None: + grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0) + else: + grad_bias = None + + if time_bench: + end_3 = torch.cuda.Event(enable_timing=True) + end_3.record() + + # ========== BF16 ========== + C_in = Qinput_t.shape[0] + C_out = grad_output.shape[-1] + grad_output_flatten = grad_output.reshape(-1, C_out) + input_flatten = Qinput_t.t().reshape(-1, C_in).to(torch.bfloat16) + weight = Qweight_t.t().to(torch.bfloat16) + + start_4 = torch.cuda.Event(enable_timing=True) + start_4.record() + + if grad_output_flatten.dtype == input_flatten.dtype: + _grad_weight = grad_output_flatten.t().mm(input_flatten) + else: + _grad_weight = grad_output_flatten.float().t().mm(input_flatten) + + if grad_output_flatten.dtype == weight.dtype: + _grad_input = grad_output_flatten.mm(weight) + else: + _grad_input = grad_output_flatten.float().mm(weight) + + end_4 = torch.cuda.Event(enable_timing=True) + end_4.record() + + torch.cuda.synchronize() + if quant_get_local_rank() == 0: + print( + f"[Backward] Part 1: {start_1.elapsed_time(end_1):.6f} ms | Part 2: {start_2.elapsed_time(end_2):.6f} ms | Part 3: {start_3.elapsed_time(end_3):.6f} ms | " + f"FP8: {start_1.elapsed_time(end_3):.6f} | BF16: {start_4.elapsed_time(end_4):.6f} | Input shape: {Qinput_t.shape} | Weight shape: {weight.shape}" + ) + + return grad_input, grad_weight, grad_bias, None, None diff --git a/llava/model/coat/activation/real_quantization/func_layernorm_noparam.py b/llava/model/coat/activation/real_quantization/func_layernorm_noparam.py new file mode 100644 index 
0000000000000000000000000000000000000000..c3a1f57f3d1db8b19da687fe823a83f35d97d7d6 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/func_layernorm_noparam.py @@ -0,0 +1,303 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Layer Normalization +==================== +In this tutorial, you will write a high-performance layer normalization +kernel that runs faster than the PyTorch implementation. + +In doing so, you will learn about: + +* Implementing backward pass in Triton. + +* Implementing parallel reduction in Triton. + +""" + +import torch +import triton +import triton.language as tl + +from ._division import _stochastic_rounding +from ._division_transpose import fp8_division_transpose +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit + +"""FP8 LayerNorm. Forward + Backward""" +"""Forward: input uses 1 * 16 quantization""" +"""Forward: output use per-tensor quantization.""" +"""Backward: input uses full-precision/BF16.""" +"""Backward: output uses full-precision/BF16""" +"""Support 3D input, but need to first flatten to 2D to perform calculation.""" + + +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["N"] // args["QB"], + "BLOCK_SN2": lambda args: args["N2"] // args["QB"], + } +) +@triton.jit +def _layer_norm_fwd_fused( + X, # pointer to the input + SX, # pointer to the scale of input + Y, # pointer to the output + SY, # pointer to the scale of output + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + scale_stride, # how much to increase the pointer when moving by 1 row + N: tl.constexpr, # number of columns in X, + N2: tl.constexpr, # number of columns in X, + SN: tl.constexpr, + SN2: tl.constexpr, + QB: tl.constexpr, + eps, # epsilon to avoid division by zero + fp8_max, + SCALE_MIN_THRES: tl.constexpr, + BLOCK_SN: tl.constexpr, + BLOCK_SN2: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
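+    # Per row, this fused kernel: (1) dequantizes the 1 x QB groups of X with
+    # their scales SX, (2) computes mean/variance in FP32 and normalizes,
+    # (3) stores mean and rstd for the backward pass, and (4) records one
+    # per-row absmax-based scale in SY; the per-tensor FP8 division itself is
+    # done afterwards in fp8_layernorm_noparam_forward via fp8_division_transpose.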
+ row = tl.program_id(0) + Y += row * stride + X += row * stride + SX += row * scale_stride + + mean = 0 + cols = tl.arange(0, N2) + scale_cols = tl.arange(0, SN2) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + scale_x = tl.load(SX + scale_cols, mask=scale_cols < SN, other=0.0).to(tl.float32) + + # Dequantize and swish calculation + scale_x = tl.reshape(scale_x, (BLOCK_SN2, 1)) + x = tl.reshape(x, (BLOCK_SN2, QB)) + x = x * scale_x + x = tl.reshape(x, (N2,)) + + # Compute mean and Variance + mean = tl.sum(x, axis=0) / N + # Compute variance + _var = (x - mean) * (x - mean) + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Write mean / rstd + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + x_hat = (x - mean) * rstd + + # Scale calculation + abs_x_hat = tl.abs(x_hat) + max_val = tl.max(abs_x_hat, axis=0) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = scale_output.to(SY.type.element_ty) + + tl.store(SY + row, scale_output) + + # Write output + tl.store(Y + cols, x_hat, mask=cols < N) + + +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["N"] // args["QB"], + "BLOCK_SN2": lambda args: args["N2"] // args["QB"], + } +) +@triton.jit +def _layer_norm_bwd_dx_fused( + DX, # pointer to the input gradient + DY, # pointer to the output gradient + X, # pointer to the input + SX, # pointer to the input + noise_ptr, # noise for stochastic + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + scale_stride, # how much to increase the pointer when moving by 1 row + N: tl.constexpr, # number of columns in X + N2: tl.constexpr, # number of columns in X + SN: tl.constexpr, + SN2: tl.constexpr, + QB: tl.constexpr, + SCALE_MIN_THRES: tl.constexpr, + STOCHASTIC: tl.constexpr, + BLOCK_SN: tl.constexpr, + BLOCK_SN2: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row = tl.program_id(0) + cols = tl.arange(0, N2) + scale_cols = tl.arange(0, SN2) + mask = cols < N + X += row * stride + DY += row * stride + DX += row * stride + SX += row * scale_stride + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + scale_x = tl.load(SX + scale_cols, mask=scale_cols < SN, other=0.0).to(tl.float32) + scale_x = tl.reshape(scale_x, (BLOCK_SN2, 1)) + x = tl.reshape(x, (BLOCK_SN2, QB)) + x = x * scale_x + x = tl.reshape(x, N2) + + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd + xhat = tl.where(mask, xhat, 0.0) + dy = tl.where(mask, dy, 0.0) + c1 = tl.sum(xhat * dy, axis=0) / N + c2 = tl.sum(dy, axis=0) / N + dx = (dy - (xhat * c1 + c2)) * rstd + + if STOCHASTIC: + # noise_ptr += row * stride + # noise_block_ptr = noise_ptr + cols + # noise = tl.load(noise_block_ptr, mask=mask, other=0.) 
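+        # Stochastic rounding path: dx is rounded with _stochastic_rounding,
+        # which expects the FP8 exponent/mantissa bit widths (e_bit, m_bit)
+        # of the target format; see convert_fp8_to_embit in common.py.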
+ + noise_offset = row * stride + cols + noise = tl.rand(0, noise_offset) + + dx = _stochastic_rounding(dx, noise, e_bit, m_bit) + + dx = dx.to(DX.type.element_ty) + + # Write dx + tl.store(DX + cols, dx, mask=mask) + + +def fp8_layernorm_noparam_forward(x, s_x, QB, eps, transpose_output_2d=False): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + + # allocate output + M, N = x.shape + _, SN = s_x.shape + y = torch.empty_like(x, dtype=torch.float32) + s_y = torch.empty( + (M,), dtype=torch.bfloat16, device=x.device + ) # We do this because we apply per-tensor quantization for it afterwards + mean = torch.empty((M,), dtype=torch.float32, device=x.device) + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + # heuristics for number of warps + num_warps = 8 + fp8MaxValue = FP8_MAX_VALUE[x.dtype] + + N2 = triton.next_power_of_2(N) + SN2 = N2 // QB + # enqueue kernel + _layer_norm_fwd_fused[(M,)]( # + x, + s_x, + y, + s_y, + mean, + rstd, # + x.stride(0), + s_x.stride(0), + N, + N2, + SN, + SN2, + QB, + eps, + fp8MaxValue, + SCALE_MIN_THRES=SCALE_MIN_THRES, # + num_warps=num_warps, + num_ctas=1, + ) + # reduction + s_y_max = s_y.max() + qy, s_y_max, qy_t = fp8_division_transpose(y, QB, x.dtype, s_y_max) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + qy = qy.reshape(BS, -1, y.shape[-1]) + if not transpose_output_2d: + qy_t = qy_t.reshape(BS, -1, y.shape[-1]) + + return qy, s_y_max, qy_t, (mean, rstd, num_warps) + + +def fp8_layernorm_noparam_backward(x, s_x, g, QB, m, v, num_warps, stochastic=False): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + + if stochastic: + # noise = torch.empty_like(g, dtype=torch.float32).uniform_(-0.5, 0.5) + noise = None + else: + noise = None + + # heuristics for amount of parallel reduction stream for DW/DB + dx = torch.empty_like(g, dtype=torch.bfloat16) + # enqueue kernel using forward pass heuristics + # also compute partial sums for DW and DB + M, N = g.shape + _, SN = s_x.shape + + N2 = triton.next_power_of_2(N) + SN2 = triton.next_power_of_2(SN) + _layer_norm_bwd_dx_fused[(M,)]( # + dx, + g, + x, + s_x, + noise, + m, + v, # + x.stride(0), + s_x.stride(0), + N, + N2, + SN, + SN2, + QB, + SCALE_MIN_THRES=SCALE_MIN_THRES, + STOCHASTIC=stochastic, + num_warps=num_warps, + ) + + if batched: + dx = dx.reshape(BS, -1, dx.shape[-1]) + + return dx diff --git a/llava/model/coat/activation/real_quantization/func_quantize.py b/llava/model/coat/activation/real_quantization/func_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..f63942146890a02132b431a10df39f22a65ebec4 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/func_quantize.py @@ -0,0 +1,102 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +from copy import deepcopy + +import torch +import torch.nn as nn + +from ._quantize import fp8_quantize +from ._quantize_pertensor import fp8_quantize_pertensor + + +class Coat_quantize_bgn(nn.Module): + def __init__(self, args=None, layer_type=""): + super().__init__() + self.args = deepcopy(args) + self.fp8type = self.args.fabit + self.layer_type = layer_type + + def forward(self, input): + if self.training: + return Coat_quantize_bgn_func.apply(input, self.args.group_size, self.fp8type) + else: + return input, None, None + + +class Coat_quantize_bgn_func(torch.autograd.Function): + @staticmethod + def forward(ctx, input, group_size, fp8type): + """ + (Qoutput, Oscale) uses 1 * 16 quantization + """ + Qoutput, Oscale = fp8_quantize(input, group_size, fp8type) + # For autograd + Qoutput = Qoutput.view(torch.float8_e4m3fn) + ctx.saved = group_size + return input, Qoutput, Oscale + + @staticmethod + def backward(ctx, grad_output, Qgrad_output, Gscale): + """ + (Qgrad_output, Gscale) uses 1 * 16 quantization + """ + return grad_output, None, None + + +class Coat_quantize_end(nn.Module): + def __init__(self, args=None, layer_type=""): + super().__init__() + self.args = deepcopy(args) + self.fp8type = self.args.babit + self.layer_type = layer_type + + def forward(self, input, Qinput, Iscale): + if self.training: + return Coat_quantize_end_func.apply(input, Qinput, Iscale, self.args.group_size, self.fp8type) + else: + return input + + +class Coat_quantize_end_func(torch.autograd.Function): + @staticmethod + def forward(ctx, input, Qinput, Iscale, group_size, fp8type): + """ + (Qinput, Iscale) uses 1 * 16 quantization + """ + ctx.saved = group_size, fp8type + + return input + + @staticmethod + def backward(ctx, grad_output): + """ + (Qgrad_output, Gscale) uses per-tensor quantization + """ + + group_size, fp8type = ctx.saved + Qgrad_output, Gscale, Gscale_g16 = fp8_quantize_pertensor(grad_output, group_size, fp8type, stochastic=False) + + # For autograd + Qgrad_output = Qgrad_output.view(torch.float8_e4m3fn) + + return grad_output, Qgrad_output, Gscale_g16, None, None diff --git a/llava/model/coat/activation/real_quantization/func_rmsnorm.py b/llava/model/coat/activation/real_quantization/func_rmsnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a8d38e2909d7a45a7792e9b1502af5aa7dee95 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/func_rmsnorm.py @@ -0,0 +1,345 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Layer Normalization +==================== +In this tutorial, you will write a high-performance layer normalization +kernel that runs faster than the PyTorch implementation. + +In doing so, you will learn about: + +* Implementing backward pass in Triton. + +* Implementing parallel reduction in Triton. + +""" + +import torch +import triton +import triton.language as tl + +from ._division_transpose import fp8_division_transpose +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES + +"""RMSNorm Forward + Backward""" +"""Forward: Input uses 1 * 16 group quantization""" +"""Forward: Output uses per-tensor quantization""" +"""Backward: Input uses full-precision/BF16.""" +"""Backward: Output uses full-precision/BF16.""" + +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["N"] // args["QB"], + "BLOCK_SN2": lambda args: args["N2"] // args["QB"], + } +) +@triton.jit +def _rms_norm_fwd_fused( + X, # pointer to the input + SX, # pointer to the scale of input + Y, # pointer to the output + SY, # pointer to the scale of output + W, # Weight + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + scale_stride, # how much to increase the pointer when moving by 1 row + N: tl.constexpr, # number of columns in X, + N2: tl.constexpr, # number of columns in X, + SN: tl.constexpr, + SN2: tl.constexpr, + QB: tl.constexpr, + eps, # epsilon to avoid division by zero + fp8_max, + SCALE_MIN_THRES: tl.constexpr, + BLOCK_SN: tl.constexpr, + BLOCK_SN2: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
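+    # Per row, this fused kernel dequantizes the 1 x QB groups of X, computes
+    # rstd = 1/sqrt(mean(x*x) + eps), applies the weight W, stores rstd for
+    # the backward pass, and records one per-row absmax scale in SY; the
+    # actual per-tensor FP8 division happens later in fp8_rmsnorm_forward.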
+ row = tl.program_id(0) + Y += row * stride + X += row * stride + SX += row * scale_stride + + cols = tl.arange(0, N2) + scale_cols = tl.arange(0, SN2) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + scale_x = tl.load(SX + scale_cols, mask=scale_cols < SN, other=0.0).to(tl.float32) + + # Dequantize and swish calculation + scale_x = tl.reshape(scale_x, (BLOCK_SN2, 1)) + x = tl.reshape(x, (BLOCK_SN2, QB)) + x = x * scale_x + x = tl.reshape(x, N2) + + # Compute variance + _var = x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + + # Write mean / rstd + tl.store(Rstd + row, rstd) + + # Normalize and apply linear transformation + w = tl.load(W + cols, mask=cols < N, other=0.0) + x_hat = x * rstd + y = x_hat * w + + # Scale calculation + abs_y = tl.abs(y) + max_val = tl.max(abs_y) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + tl.store(SY + row, scale_output) + + # Write output + tl.store(Y + cols, y, mask=cols < N) + + +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["N"] // args["QB"], + "BLOCK_SN2": lambda args: args["N2"] // args["QB"], + } +) +@triton.jit +def _rms_norm_bwd_dx_fused( + DX, # pointer to the input gradient + DY, # pointer to the output gradient + DW, + X, # pointer to the input + SX, # pointer to the input + W, # weight + Rstd, # pointer to the 1/std + Lock, + stride, # how much to increase the pointer when moving by 1 row + scale_stride, # how much to increase the pointer when moving by 1 row + N: tl.constexpr, # number of columns in X, + N2: tl.constexpr, # number of columns in X, + SN: tl.constexpr, + SN2: tl.constexpr, + QB: tl.constexpr, + SCALE_MIN_THRES: tl.constexpr, + BLOCK_SN: tl.constexpr, + BLOCK_SN2: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. 
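+    # Per row, this kernel dequantizes X, recomputes xhat = x * rstd, forms
+    # dx = (w*dy - xhat * mean(xhat * w * dy)) * rstd (RMSNorm has no mean
+    # term, unlike LayerNorm), and accumulates partial dW sums into one of
+    # GROUP_SIZE_M buffers guarded by a spin lock; _rms_norm_bwd_dwdb reduces
+    # those buffers into the final weight gradient.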
+ row = tl.program_id(0) + cols = tl.arange(0, N2) + scale_cols = tl.arange(0, SN2) + mask = cols < N + X += row * stride + DY += row * stride + DX += row * stride + SX += row * scale_stride + + # Offset locks and weights/biases gradient pointer for parallel reduction + lock_id = row % GROUP_SIZE_M + Lock += lock_id + Count = Lock + GROUP_SIZE_M + DW = DW + lock_id * N + cols + + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + scale_x = tl.load(SX + scale_cols, mask=scale_cols < SN, other=0.0).to(tl.float32) + scale_x = tl.reshape(scale_x, (BLOCK_SN2, 1)) + x = tl.reshape(x, (BLOCK_SN2, QB)) + x = x * scale_x + x = tl.reshape(x, N2) + dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32) + + # Load weight + w = tl.load(W + cols, mask=cols < N, other=0.0).to(tl.float32) + rstd = tl.load(Rstd + row).to(tl.float32) + + # Compute dx + xhat = x * rstd + wdy = w * dy + xhat = tl.where(mask, xhat, 0.0) + wdy = tl.where(mask, wdy, 0.0) + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - (xhat * c1)) * rstd # layer norm have c2 term, rmsnorm do not + + dx = dx.to(DX.type.element_ty) + + # Write dx + tl.store(DX + cols, dx, mask=mask) + + # Accumulate partial sums for dw/db + partial_dw = (dy * xhat).to(w.dtype) + while tl.atomic_cas(Lock, 0, 1) == 1: + pass + count = tl.load(Count) + # First store doesn't accumulate + if count == 0: + tl.atomic_xchg(Count, 1) + else: + partial_dw += tl.load(DW, mask=mask) + tl.store(DW, partial_dw, mask=mask) + # Release the lock + tl.atomic_xchg(Lock, 0) + + +@triton.jit +def _rms_norm_bwd_dwdb( + DW, # pointer to the partial sum of weights gradient + FINAL_DW, # pointer to the weights gradient + M, # GROUP_SIZE_M + N, # number of columns + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + # Map the program id to the elements of DW and DB it should compute. + pid = tl.program_id(0) + cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # Iterate through the rows of DW and DB to sum the partial sums. + for i in range(0, M, BLOCK_SIZE_M): + rows = i + tl.arange(0, BLOCK_SIZE_M) + mask = (rows[:, None] < M) & (cols[None, :] < N) + offs = rows[:, None] * N + cols[None, :] + dw += tl.load(DW + offs, mask=mask, other=0.0) + # Write the final sum to the output. 
+ sum_dw = tl.sum(dw, axis=0) + tl.store(FINAL_DW + cols, sum_dw, mask=cols < N) + + +def fp8_rmsnorm_forward(x, s_x, w, QB, eps, transpose_output_2d=False): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + + # allocate output + M, N = x.shape + _, SN = s_x.shape + y = torch.empty_like(x, dtype=torch.bfloat16) + s_y = torch.empty((M,), dtype=torch.bfloat16, device=x.device) + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + # heuristics for number of warps + num_warps = 8 + fp8MaxValue = FP8_MAX_VALUE[x.dtype] + + N2 = triton.next_power_of_2(N) + SN2 = N2 // QB + + # import os + # if int(os.environ.get("LOCAL_RANK")) == 7: + # print(x.device, x.shape, x.dtype, s_x.shape, s_x.dtype, y.shape, y.dtype, s_y.shape, s_y.dtype, w.shape, w.dtype, rstd.shape, rstd.dtype, x.stride(0), s_x.stride(0), N, N2, SN, SN2, QB, "\n") + # enqueue kernel + _rms_norm_fwd_fused[(M,)]( # + x, + s_x, + y, + s_y, + w, + rstd, # + x.stride(0), + s_x.stride(0), + N, + N2, + SN, + SN2, + QB, + eps, + fp8MaxValue, + SCALE_MIN_THRES=SCALE_MIN_THRES, # + num_warps=num_warps, + num_ctas=1, + ) + + # reduction + s_y_max = s_y.max() + qy, s_y_max, qy_t = fp8_division_transpose(y, QB, x.dtype, s_y_max) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + qy = qy.reshape(BS, -1, y.shape[-1]) + if not transpose_output_2d: + qy_t = qy_t.reshape(BS, -1, y.shape[-1]) + + return qy, s_y_max, qy_t, (w.clone(), rstd, num_warps) + + +def fp8_rmsnorm_backward(x, s_x, g, w, v, QB, num_warps): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + + # enqueue kernel using forward pass heuristics + # also compute partial sums for DW and DB + M, N = g.shape + _, SN = s_x.shape + + GROUP_SIZE_M = 128 + # heuristics for amount of parallel reduction stream for DW/DB + locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=w.device) + _dw = torch.zeros((GROUP_SIZE_M, N), dtype=w.dtype, device=w.device) + dw = torch.empty((N,), dtype=w.dtype, device=w.device) + + dx = torch.empty_like(g, dtype=torch.bfloat16) + + N2 = triton.next_power_of_2(N) + SN2 = triton.next_power_of_2(SN) + _rms_norm_bwd_dx_fused[(M,)]( # + dx, + g, + _dw, + x, + s_x, + w, + v, + locks, # + x.stride(0), + s_x.stride(0), + N, + N2, + SN, + SN2, + QB, + SCALE_MIN_THRES=SCALE_MIN_THRES, + num_warps=num_warps, + GROUP_SIZE_M=GROUP_SIZE_M, + ) + + if batched: + dx = dx.reshape(BS, -1, dx.shape[-1]) + + grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])] + # accumulate partial sums in separate kernel + _rms_norm_bwd_dwdb[grid](_dw, dw, min(GROUP_SIZE_M, M), N, BLOCK_SIZE_M=32, BLOCK_SIZE_N=128, num_ctas=1) # # + + return dx, dw diff --git a/llava/model/coat/activation/real_quantization/gelu_bwd.py b/llava/model/coat/activation/real_quantization/gelu_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..8772707f7f5e6ef9b68b9c29f2b1b329f2c3177d --- /dev/null +++ b/llava/model/coat/activation/real_quantization/gelu_bwd.py @@ -0,0 +1,235 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
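+# The backward kernel in this file evaluates the exact (erf-based) GELU
+# derivative, d/dx gelu(x) = Phi(x) + x * phi(x), where Phi is the standard
+# normal CDF and phi its density, and multiplies it with the incoming gradient
+# before re-quantizing per tensor. A commented-out PyTorch reference of the
+# same derivative (the function name is illustrative):
+#
+#     import torch
+#     def ref_gelu_grad(x):
+#         cdf = 0.5 * (1.0 + torch.erf(x / 2 ** 0.5))
+#         pdf = torch.exp(-0.5 * x * x) / (2 * torch.pi) ** 0.5
+#         return cdf + x * pdf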
+ +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division import fp8_division +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block + +"""GELU Activation Backward""" +"""Input uses 1 * 16 group quantization""" +"""Grad uses full-precision/BF16""" +"""Output uses per-tensor quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_gelu_backward_kernel( + output_ptr, + output_scale_ptr, # output + input_ptr, + input_scale_ptr, # input + grad_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input_stride_0, + input_stride_1, # input stride + s_input_stride_0, + s_input_stride_1, # scale of input stride + grad_stride_0, + grad_stride_1, # input stride + output_stride_0, + output_stride_1, # output stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input_ptr = tl.make_block_ptr( + base=input_scale_ptr, + shape=(M, SN), + strides=(s_input_stride_0, s_input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + scale_input = tl.load(scale_input_ptr) + + input = input.to(tl.float32) + scale_input = scale_input.to(tl.float32) + + # Dequantize and gelu calculation + scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1)) + input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + input = input * scale_input + + # pointers of gradient + grad_block_ptr = tl.make_block_ptr( + base=grad_ptr, + shape=(M, N), + strides=(grad_stride_0, grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + grad = tl.load(grad_block_ptr) + grad = grad.to(tl.float32) + grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB)) + + # Actual Calculation of GELU's backward + pi = float(torch.pi) + cdf = (1.0 + tl.math.erf(input / tl.math.sqrt(2.0))) / 2 + exp = input * tl.exp(-libdevice.pow(input, 2) / 2) / tl.sqrt(2 * pi) + gelu_output = cdf + exp + + gelu_output = gelu_output * grad + + # Quantize Scale calculation + 
abs_output = tl.abs(gelu_output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize + # gelu_output = tl.fdiv(gelu_output, scale_output) + gelu_output = gelu_output.to(output_ptr.type.element_ty) + + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + gelu_output = tl.reshape(gelu_output, (BLOCK_M, BLOCK_N)) + + # debug + # gelu_output = input + # scale_output = scale_input + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(output_block_ptr, gelu_output, boundary_check=(0, 1)) + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_gelu_backward(x, s_x, g, QB, fp8type): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G + + if isinstance(fp8type, str): + fp8type = convert_str_to_fp8[fp8type] + y = torch.empty_like(g, dtype=torch.bfloat16) + s_y = torch.empty_like(s_x, dtype=torch.bfloat16) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_gelu_backward_kernel[grid]( + y, + s_y, + x, + s_x, + g, + M, + N, + SN, + QB, + fp8MaxValue, + x.stride(0), + x.stride(1), + s_x.stride(0), + s_x.stride(1), + g.stride(0), + g.stride(1), + y.stride(0), + y.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + # Per-tensor quantization + s_y_max = s_y.max() + qy, s_y_max = fp8_division(y, QB, fp8type, s_y_max) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + qy = qy.reshape(BS, -1, qy.shape[-1]) + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + + return qy, s_y_max diff --git a/llava/model/coat/activation/real_quantization/gelu_bwd_legacy.py b/llava/model/coat/activation/real_quantization/gelu_bwd_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..6f2b36bcb8c51d33392cd740c3dcbf6de4c9f082 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/gelu_bwd_legacy.py @@ -0,0 +1,257 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division import fp8_division +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block + +"""GELU Activation Backward""" +"""Input uses 1 * 16 group quantization""" +"""Grad uses 1 * 16 group quantization""" +"""Output uses per-tensor quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_gelu_backward_legacy_kernel( + output_ptr, + output_scale_ptr, # output + input_ptr, + input_scale_ptr, # input + grad_ptr, + grad_scale_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input_stride_0, + input_stride_1, # input stride + s_input_stride_0, + s_input_stride_1, # scale of input stride + grad_stride_0, + grad_stride_1, # input stride + s_grad_stride_0, + s_grad_stride_1, # scale of input stride + output_stride_0, + output_stride_1, # output stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input_ptr = tl.make_block_ptr( + base=input_scale_ptr, + shape=(M, SN), + strides=(s_input_stride_0, s_input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + scale_input = tl.load(scale_input_ptr) + + input = input.to(tl.float32) + scale_input = scale_input.to(tl.float32) + + # Dequantize and gelu calculation + scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1)) + input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + input = input * scale_input + + # pointers of gradient + grad_block_ptr = tl.make_block_ptr( + base=grad_ptr, + shape=(M, N), + strides=(grad_stride_0, grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # grad ptr + scale_grad_ptr = tl.make_block_ptr( + base=grad_scale_ptr, + shape=(M, SN), + strides=(s_grad_stride_0, s_grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + grad = tl.load(grad_block_ptr) + scale_grad = tl.load(scale_grad_ptr) + + grad = grad.to(tl.float32) + scale_grad = scale_grad.to(tl.float32) + + # Dequantize and gelu calculation + scale_grad = tl.reshape(scale_grad, (BLOCK_M, BLOCK_SN, 1)) + grad = tl.reshape(grad, 
(BLOCK_M, BLOCK_SN, QB)) + grad = grad * scale_grad + + # Actual Calculation of GELU's backward + pi = float(torch.pi) + cdf = (1.0 + tl.math.erf(input / tl.math.sqrt(2.0))) / 2 + exp = input * tl.exp(-libdevice.pow(input, 2) / 2) / tl.sqrt(2 * pi) + gelu_output = cdf + exp + + gelu_output = gelu_output * grad + + # Quantize Scale calculation + abs_output = tl.abs(gelu_output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize + # gelu_output = tl.fdiv(gelu_output, scale_output) + gelu_output = gelu_output.to(output_ptr.type.element_ty) + + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + gelu_output = tl.reshape(gelu_output, (BLOCK_M, BLOCK_N)) + + # debug + # gelu_output = input + # scale_output = scale_input + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(output_block_ptr, gelu_output, boundary_check=(0, 1)) + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_gelu_backward_legacy(x, s_x, g, s_g, QB): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + s_g = s_g.reshape(-1, s_g.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G + + y = torch.empty_like(g, dtype=torch.bfloat16) + s_y = torch.empty_like(s_g, dtype=s_g.dtype) + fp8MaxValue = FP8_MAX_VALUE[g.dtype] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_gelu_backward_legacy_kernel[grid]( + y, + s_y, + x, + s_x, + g, + s_g, + M, + N, + SN, + QB, + fp8MaxValue, + x.stride(0), + x.stride(1), + s_x.stride(0), + s_x.stride(1), + g.stride(0), + g.stride(1), + s_g.stride(0), + s_g.stride(1), + y.stride(0), + y.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + # Per-tensor quantization + s_y_max = s_y.max() + qy, s_y_max = fp8_division(y, QB, g.dtype, s_y_max) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + qy = qy.reshape(BS, -1, qy.shape[-1]) + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + + return qy, s_y_max diff --git a/llava/model/coat/activation/real_quantization/gelu_fwd.py b/llava/model/coat/activation/real_quantization/gelu_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..547c6f8339feef79870ade12a407c852631ae136 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/gelu_fwd.py @@ -0,0 +1,209 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
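+# The forward kernel in this file applies the exact (erf-based) GELU,
+# gelu(x) = x * Phi(x) with Phi the standard normal CDF, to the dequantized
+# input; per-group absmax scales are computed in the kernel and the final FP8
+# division is done in the Python wrapper. Numerically the activation matches
+# torch.nn.functional.gelu with its default (non-tanh) formulation.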
+ +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division_transpose import fp8_division_transpose +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block + +"""GELU Activation Forward""" +"""Input uses 1 * 16 group quantization""" +"""Output uses 1 * 16 group quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + +__all__ = ["fp8_gelu_forward"] + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_gelu_forward_kernel( + output_ptr, + output_scale_ptr, # output + input_ptr, + input_scale_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input_stride_0, + input_stride_1, # input stride + s_input_stride_0, + s_input_stride_1, # scale of input stride + output_stride_0, + output_stride_1, # output stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input_ptr = tl.make_block_ptr( + base=input_scale_ptr, + shape=(M, SN), + strides=(s_input_stride_0, s_input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + scale_input = tl.load(scale_input_ptr) + + input = input.to(tl.float32) + scale_input = scale_input.to(tl.float32) + + # Dequantize and gelu calculation + scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1)) + input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + input = input * scale_input + + # Actual Calculation of GeLU + cdf = (1.0 + tl.math.erf(input / tl.math.sqrt(2.0))) / 2 + gelu_output = cdf * input + + # Quantize Scale calculation + abs_output = tl.abs(gelu_output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize + # gelu_output = tl.fdiv(gelu_output, scale_output) + gelu_output = gelu_output.to(output_ptr.type.element_ty) + + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + gelu_output = tl.reshape(gelu_output, (BLOCK_M, BLOCK_N)) + + # debug + # gelu_output = input + # scale_output = 
scale_input + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(output_block_ptr, gelu_output, boundary_check=(0, 1)) + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_gelu_forward(x, s_x, QB, transpose_output_2d=False): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G + + y = torch.empty_like(x, dtype=torch.bfloat16) + s_y = torch.empty_like(s_x, dtype=s_x.dtype) + fp8MaxValue = FP8_MAX_VALUE[x.dtype] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_gelu_forward_kernel[grid]( + y, + s_y, + x, + s_x, + M, + N, + SN, + QB, + fp8MaxValue, + x.stride(0), + x.stride(1), + s_x.stride(0), + s_x.stride(1), + y.stride(0), + y.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + s_y_max = s_y.max() + qy, s_y_max, qy_t = fp8_division_transpose(y, QB, x.dtype, s_y_max) + + # Recover 2D to 3D + if batched: + qy = qy.reshape(BS, -1, qy.shape[-1]) + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + if not transpose_output_2d: + qy_t = qy_t.reshape(BS, -1, qy.shape[-1]) + + return qy, s_y_max, qy_t diff --git a/llava/model/coat/activation/real_quantization/linear.py b/llava/model/coat/activation/real_quantization/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..d643f18f51432efc3a2f80c55b73a18e99f0eb75 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/linear.py @@ -0,0 +1,307 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
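+# The matmul kernel in this file multiplies two per-tensor-quantized FP8
+# operands, accumulates in FP32 with tl.dot, and dequantizes the accumulator
+# with the product of the two scales, i.e. conceptually
+#
+#     C ~= (A_fp8 @ B_fp8) * (scale_A * scale_B) + bias
+#
+# With output_quantize=True the result is additionally re-quantized into
+# 1 x QB groups (optionally with stochastic rounding); otherwise it is written
+# back in BF16, e.g. as input to FlashAttention.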
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +try: + from ._division import _stochastic_rounding + from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8 +except: + from common import SCALE_MIN_THRES, FP8_MAX_VALUE, convert_str_to_fp8, convert_fp8_to_embit + from COAT.coat.activation.real_quantization._division import _stochastic_rounding + +import os +import time + +"""Linear Layer Forward + Backward""" +"""Input uses per-tensor quantization""" +"""Output is full-precision/BF16 (for FlashAttention) or 1 * 16 quantization (for the rest)""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +def get_configs_io_block(): + configs = [] + for nstages in [3]: + for block_m in [128, 256]: + for block_n in [128, 256]: + for block_k in [128, 256]: + for nwarps in [8]: + configs.append( + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k}, + num_stages=nstages, + num_warps=nwarps, + ) + ) + return configs + + +# @triton.autotune( +# configs=get_configs_io_block(), +# key=["M", "N", "K"], +# ) +@triton.jit +def _fp8matmul_kernel( + A, + B, + C, + noise_ptr, # noise for stochastic + M: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, # + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, ## + Scale_A, + Scale_B, + Scale_C, + stride_scm, + stride_scn, + output_quantize: tl.constexpr, + QB: tl.constexpr, # default to use 1 * 16 quantization + BIAS, + fp8_max: tl.constexpr, + e_bit: tl.constexpr, + m_bit: tl.constexpr, + SCALE_MIN_THRES: tl.constexpr, + STOCHASTIC: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, +): + # matrix multiplication + pid = tl.program_id(0) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in tl.range(0, tl.cdiv(K, BLOCK_K)): + k_remaining = K - k * BLOCK_K + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0) + + acc = tl.dot(a, b, acc) + + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + scale_a = tl.load(Scale_A) + scale_b = tl.load(Scale_B) + scale_ab = scale_a.to(tl.float32) * scale_b.to(tl.float32) + # fp8 dequantize + acc = acc * scale_ab + + if BIAS: + bias = tl.load(BIAS + rbn) + acc = acc + bias + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + + if output_quantize: + acc = tl.reshape(acc, (BLOCK_M, BLOCK_N // QB, QB)) + abs_acc = tl.abs(acc) + acc_max = 
tl.max(abs_acc, axis=2) + SCALE_MIN_THRES + # tl.device_print("acc_max", acc_max) + acc_scale = acc_max / fp8_max + # tl.device_print("acc_scale", acc_scale) + acc_scale = tl.reshape(acc_scale, (BLOCK_M, BLOCK_N // QB, 1)) + acc = tl.fdiv(acc, acc_scale) + acc = tl.reshape(acc, (BLOCK_M, BLOCK_N)) + + if STOCHASTIC: + noise_block_ptr = noise_ptr + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + noise = tl.load(noise_block_ptr, boundary_check=(0, 1)) + acc = _stochastic_rounding(acc, noise, e_bit, m_bit) + + acc_scale = tl.reshape(acc_scale, (BLOCK_M, BLOCK_N // QB)) + acc_scale = acc_scale.to(Scale_C.type.element_ty) + acc = acc.to(C.dtype.element_ty) + + rsm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rsn = pid_n * BLOCK_N // QB + tl.arange(0, BLOCK_N // QB) + Scale_C = Scale_C + (rsm[:, None] * stride_scm + rsn[None, :] * stride_scn) + + tl.store(C, acc, mask=mask) + tl.store(Scale_C, acc_scale) + + else: + # handles write-back with reduction-splitting + acc = acc.to(C.dtype.element_ty) + tl.store(C, acc, mask=mask) + + +def fp8matmul(a, b, output_quantize, scale_a, scale_b, QB, bias=None, stochastic=False): + # Deal with batched input + if len(a.shape) == 3: + BS, batched = a.shape[0], True + a = a.reshape(-1, a.shape[2]) + else: + batched = False + + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + fp8MaxValue = FP8_MAX_VALUE[a.dtype] # E4M3 and E5M2 have different max value + e_bit, m_bit = convert_fp8_to_embit[a.dtype] + + # Allocates output. + if output_quantize: + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + # c = torch.empty((M, N), device=a.device, dtype=torch.float32) + scale_c = torch.empty((M, N // QB), device=a.device, dtype=torch.bfloat16) + else: + c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16) + scale_c = torch.empty( + (1, 1), device=a.device, dtype=torch.bfloat16 + ) # This line is useless, equivalent to scale_c = None + + if stochastic: + noise = torch.empty_like(c, dtype=torch.float32).uniform_(-0.5, 0.5) + else: + noise = None + + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + _fp8matmul_kernel[grid]( + a, + b, + c, + noise, # + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + scale_a, + scale_b, + scale_c, + scale_c.stride(0), + scale_c.stride(1), + output_quantize=output_quantize, + QB=QB, + BIAS=bias, + fp8_max=fp8MaxValue, + e_bit=e_bit, + m_bit=m_bit, + SCALE_MIN_THRES=SCALE_MIN_THRES, + STOCHASTIC=stochastic, + BLOCK_M=128, + BLOCK_N=256, + BLOCK_K=128, + GROUP_M=8, + num_stages=3, + num_warps=8, + ) + # Reshape output to batch + if batched: + c = c.reshape(BS, -1, N) + if output_quantize: + scale_c = scale_c.reshape(BS, -1, N // QB) + return c, scale_c + else: + if output_quantize: + scale_c = scale_c.reshape(M, N // QB) + return c, scale_c + return c + + +def fp8_linear_forward(x, s, w, s_w, output_quantize, QB, bias=None): + assert s.numel() == 1, f"X uses per-tensor quantization in linear forward, but the scale shape is {s.shape}" + assert s_w.numel() == 1, f"W uses per-tensor quantization in linear forward, but the scale shape is {s_w.shape}" + + w_t = w.t() + return fp8matmul(x, w_t, output_quantize, s, s_w, QB, bias) + + +# def fp8_linear_forward(x, s, w, s_w, output_quantize, QB): +# print("you are using the wrong linear function. 
") +# w_t = w.t() +# if output_quantize: +# return fp8matmul(x, w_t, True, s, s_w, QB) +# else: +# y = fp8matmul(x, w_t, False, s, s_w, QB) + +# return y + + +def fp8_linear_backward( + x_t, s, g, s_g, g_t, w_t, s_w, QB, bias=None, stochastic=False, dgrad_quantize=False +): # dgrad_quantize=True for backward before flashattention + assert s.numel() == 1, f"X uses per-tensor quantization in linear backward, but the scale shape is {s.shape}" + assert s_g.numel() == 1, f"G uses per-tensor quantization in linear backward, but the scale shape is {s.shape}" + assert s_w.numel() == 1, f"W uses per-tensor quantization in linear backward, but the scale shape is {s_w.shape}" + + batched = False + if len(g.shape) == 3: # others must be of 2D! + batched = True + BS = g.shape[0] + g = g.reshape(-1, g.shape[-1]) + + w_t_t = w_t.t() + x_t_t = x_t.t() + if dgrad_quantize: + y, s_y = fp8matmul(g, w_t_t, True, s_g, s_w, QB, stochastic=stochastic) + else: + y = fp8matmul(g, w_t_t, False, s_g, s_w, QB) + + w_g = fp8matmul(g_t, x_t_t, False, s_g, s, QB) + + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + if dgrad_quantize: + if s_y.numel() > 1: + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + if dgrad_quantize: + return y, s_y, w_g + else: + return y, w_g diff --git a/llava/model/coat/activation/real_quantization/mul_bwd.py b/llava/model/coat/activation/real_quantization/mul_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..27f8b2839495d6d4f92fa35d537ea75e3e04c0c2 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/mul_bwd.py @@ -0,0 +1,320 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division_transpose import fp8_division_transpose +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block + +"""Element-wise Multiplication Backward""" +"""Input1 (Gate) uses 1 * 16 group quantization""" +"""Input2 (Up) uses 1 * 16 group quantization""" +"""Grad (Down) uses full-precision/BF16""" +"""Output1 (Gate) uses full-precision/BF16""" +"""Output2 (Up) uses per-tensor quantization, we can choose whether it should be quantized inside this function""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_mul_backward_kernel( + output1_ptr, # output + output2_ptr, + output2_scale_ptr, # output + input1_ptr, + input1_scale_ptr, # input + input2_ptr, + input2_scale_ptr, # input + grad_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, + e_bit, + m_bit, # shape + input1_stride_0, + input1_stride_1, # input1 stride + s_input1_stride_0, + s_input1_stride_1, # scale of input1 stride + input2_stride_0, + input2_stride_1, # input2 stride + s_input2_stride_0, + s_input2_stride_1, # scale of input2 stride + grad_stride_0, + grad_stride_1, # input stride + output1_stride_0, + output1_stride_1, # output stride + output2_stride_0, + output2_stride_1, # output stride + s_output2_stride_0, + s_output2_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + STOCHASTIC: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # --- The first input --- + input1_block_ptr = tl.make_block_ptr( + base=input1_ptr, + shape=(M, N), + strides=(input1_stride_0, input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input1_ptr = tl.make_block_ptr( + base=input1_scale_ptr, + shape=(M, SN), + strides=(s_input1_stride_0, s_input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input1 = tl.load(input1_block_ptr) + scale_input1 = tl.load(scale_input1_ptr) + + input1 = input1.to(tl.float32) + scale_input1 = scale_input1.to(tl.float32) + + # Dequantize and mul calculation + scale_input1 = tl.reshape(scale_input1, (BLOCK_M, BLOCK_SN, 1)) + input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB)) + input1 = input1 * scale_input1 + + # --- The second input --- + input2_block_ptr = tl.make_block_ptr( + base=input2_ptr, + shape=(M, N), + strides=(input2_stride_0, input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input2_ptr = tl.make_block_ptr( + base=input2_scale_ptr, + shape=(M, SN), + strides=(s_input2_stride_0, s_input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input2 = tl.load(input2_block_ptr) + scale_input2 = tl.load(scale_input2_ptr) + + input2 = input2.to(tl.float32) + scale_input2 = scale_input2.to(tl.float32) + + # 
Dequantize and mul calculation + scale_input2 = tl.reshape(scale_input2, (BLOCK_M, BLOCK_SN, 1)) + input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB)) + input2 = input2 * scale_input2 + + # pointers of gradient + grad_block_ptr = tl.make_block_ptr( + base=grad_ptr, + shape=(M, N), + strides=(grad_stride_0, grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + grad = tl.load(grad_block_ptr) + grad = grad.to(tl.float32) + grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB)) + + # Actual Calculation of Mul Backward + grad1 = grad * input2 + grad1 = tl.reshape(grad1, (BLOCK_M, BLOCK_N)) + grad1 = grad1.to(output1_ptr.type.element_ty) + + # pointers + output1_block_ptr = tl.make_block_ptr( + base=output1_ptr, + shape=(M, N), + strides=(output1_stride_0, output1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + tl.store(output1_block_ptr, grad1, boundary_check=(0, 1)) + + # Actual Calculation of Mul Backward + grad2 = grad * input1 + # Quantize the grad 1 - Scale calculation + abs_grad2 = tl.abs(grad2) + max_val = tl.max(abs_grad2, axis=2) + SCALE_MIN_THRES + scale_grad2 = max_val / fp8_max + scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN, 1)) + # Quantize + # grad1 = tl.fdiv(grad1, scale_output) # do not quantize the output due to the data flow + grad2 = grad2.to(output2_ptr.type.element_ty) + scale_grad2 = scale_grad2.to(output2_scale_ptr.type.element_ty) + scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN)) + grad2 = tl.reshape(grad2, (BLOCK_M, BLOCK_N)) + + # pointers + output2_block_ptr = tl.make_block_ptr( + base=output2_ptr, + shape=(M, N), + strides=(output2_stride_0, output2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output2_ptr = tl.make_block_ptr( + base=output2_scale_ptr, + shape=(M, SN), + strides=(s_output2_stride_0, s_output2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + tl.store(output2_block_ptr, grad2, boundary_check=(0, 1)) + tl.store(scale_output2_ptr, scale_grad2, boundary_check=(0, 1)) + + +def fp8_mul_backward( + x1, s_x1, x2, s_x2, g, QB, fp8type, stochastic=False, output_quantized_transpose=False +): # Stochastic Rounding is left outside this function + # Change batched 3D input to 2D + batched = False + if len(x1.shape) == 3: + assert len(s_x1.shape) == 3 + batched = True + BS = x1.shape[0] + x1 = x1.reshape(-1, x1.shape[-1]) + s_x1 = s_x1.reshape(-1, s_x1.shape[-1]) + x2 = x2.reshape(-1, x2.shape[-1]) + s_x2 = s_x2.reshape(-1, s_x2.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + + if stochastic: + noise = torch.empty_like(g, dtype=torch.float32).uniform_(-0.5, 0.5) + else: + noise = None + + # defining the input and output tensor + M, N = x1.shape + _, SN = s_x1.shape # assume the shape of quantization block size is always 1 * G + assert x1.shape == x2.shape + assert s_x1.shape == s_x2.shape + + if isinstance(fp8type, str): + fp8type = convert_str_to_fp8[fp8type] + y1 = torch.empty_like(g, dtype=torch.bfloat16) + y2 = torch.empty_like(g, dtype=torch.bfloat16) + s_y2 = torch.empty_like(s_x1, dtype=torch.bfloat16) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + e_bit, m_bit = convert_fp8_to_embit[fp8type] + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_mul_backward_kernel[grid]( + 
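+        # launch arguments below: output buffers (y1, y2, s_y2), the quantized inputs with
+        # their per-group scales, the BF16 gradient, logical shapes, strides, and constexpr flags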
y1, + y2, + s_y2, + x1, + s_x1, + x2, + s_x2, + g, + M, + N, + SN, + QB, + fp8MaxValue, + e_bit, + m_bit, + x1.stride(0), + x1.stride(1), + s_x1.stride(0), + s_x1.stride(1), + x2.stride(0), + x2.stride(1), + s_x2.stride(0), + s_x2.stride(1), + g.stride(0), + g.stride(1), + y1.stride(0), + y1.stride(1), + y2.stride(0), + y2.stride(1), + s_y2.stride(0), + s_y2.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + STOCHASTIC=stochastic, + ) + + if not output_quantized_transpose: + # Recover 2D to 3D + if batched: + y1 = y1.reshape(BS, -1, y1.shape[-1]) + + y2 = y2.reshape(BS, -1, y2.shape[-1]) + s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1]) + + return y1, (y2, s_y2) + else: + # Per-tensor quantization + s_y2_max = s_y2.max() + qy2, s_y2_max, qy2_t = fp8_division_transpose(y2, QB, fp8type, s_y2_max) + + # Recover 2D to 3D + if batched: + y1 = y1.reshape(BS, -1, y1.shape[-1]) + + y2 = y2.reshape(BS, -1, y2.shape[-1]) + qy2 = qy2.reshape(BS, -1, qy2.shape[-1]) + s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1]) + + return y1, (qy2, s_y2_max, qy2_t) diff --git a/llava/model/coat/activation/real_quantization/mul_bwd_legacy.py b/llava/model/coat/activation/real_quantization/mul_bwd_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..803789e83cb27330d6195a515805440af45cc528 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/mul_bwd_legacy.py @@ -0,0 +1,374 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division import _stochastic_rounding +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, get_configs_io_block + +"""Element-wise Multiplication Backward""" +"""Input1 (Gate) uses 1 * 16 group quantization""" +"""Input2 (Up) uses 1 * 16 group quantization""" +"""Grad (Down) uses 1 * 16 group quantization""" +"""Output1 (Gate) uses 1 * 16 quantization""" +"""Output2 (Up) uses per-tensor quantization, but should be quantized outside this function""" # Although it is per-tensor quantization, we only apply per-group quantization here, and the reduction should be performed outside this function. 
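+"""Illustrative call (names are placeholders): y1, s_y1, y2, s_y2 = fp8_mul_backward_legacy(x1, s_x1, x2, s_x2, g, s_g, QB=16)"""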
+"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_mul_backward_legacy_kernel( + output1_ptr, + output1_scale_ptr, # output + output2_ptr, + output2_scale_ptr, # output + input1_ptr, + input1_scale_ptr, # input + input2_ptr, + input2_scale_ptr, # input + grad_ptr, + grad_scale_ptr, # input + noise_ptr, # noise for stochastic + M, + N, + SN, + QB: tl.constexpr, + fp8_max, + e_bit, + m_bit, # shape + input1_stride_0, + input1_stride_1, # input1 stride + s_input1_stride_0, + s_input1_stride_1, # scale of input1 stride + input2_stride_0, + input2_stride_1, # input2 stride + s_input2_stride_0, + s_input2_stride_1, # scale of input2 stride + grad_stride_0, + grad_stride_1, # input stride + s_grad_stride_0, + s_grad_stride_1, # scale of input stride + output1_stride_0, + output1_stride_1, # output stride + s_output1_stride_0, + s_output1_stride_1, # scale of output stride + output2_stride_0, + output2_stride_1, # output stride + s_output2_stride_0, + s_output2_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + STOCHASTIC: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # --- The first input --- + input1_block_ptr = tl.make_block_ptr( + base=input1_ptr, + shape=(M, N), + strides=(input1_stride_0, input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input1_ptr = tl.make_block_ptr( + base=input1_scale_ptr, + shape=(M, SN), + strides=(s_input1_stride_0, s_input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input1 = tl.load(input1_block_ptr) + scale_input1 = tl.load(scale_input1_ptr) + + input1 = input1.to(tl.float32) + scale_input1 = scale_input1.to(tl.float32) + + # Dequantize and mul calculation + scale_input1 = tl.reshape(scale_input1, (BLOCK_M, BLOCK_SN, 1)) + input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB)) + input1 = input1 * scale_input1 + + # --- The second input --- + input2_block_ptr = tl.make_block_ptr( + base=input2_ptr, + shape=(M, N), + strides=(input2_stride_0, input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input2_ptr = tl.make_block_ptr( + base=input2_scale_ptr, + shape=(M, SN), + strides=(s_input2_stride_0, s_input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input2 = tl.load(input2_block_ptr) + scale_input2 = tl.load(scale_input2_ptr) + + input2 = input2.to(tl.float32) + scale_input2 = scale_input2.to(tl.float32) + + # Dequantize and mul calculation + scale_input2 = tl.reshape(scale_input2, (BLOCK_M, BLOCK_SN, 1)) + input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB)) + input2 = input2 * scale_input2 + + # pointers of gradient + grad_block_ptr = tl.make_block_ptr( + base=grad_ptr, + shape=(M, N), + strides=(grad_stride_0, grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # grad ptr + scale_grad_ptr = tl.make_block_ptr( 
+ base=grad_scale_ptr, + shape=(M, SN), + strides=(s_grad_stride_0, s_grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + grad = tl.load(grad_block_ptr) + scale_grad = tl.load(scale_grad_ptr) + + grad = grad.to(tl.float32) + scale_grad = scale_grad.to(tl.float32) + + # Dequantize and swish calculation + scale_grad = tl.reshape(scale_grad, (BLOCK_M, BLOCK_SN, 1)) + grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB)) + grad = grad * scale_grad + + # Actual Calculation of Mul Backward + grad1 = grad * input2 + # Quantize the grad 1 - Scale calculation + abs_grad1 = tl.abs(grad1) + max_val = tl.max(abs_grad1, axis=2) + SCALE_MIN_THRES + scale_grad1 = max_val / fp8_max + scale_grad1 = tl.reshape(scale_grad1, (BLOCK_M, BLOCK_SN, 1)) + # Quantize + grad1 = tl.fdiv(grad1, scale_grad1) # do not quantize the output due to the data flow + scale_grad1 = scale_grad1.to(output1_scale_ptr.type.element_ty) + scale_grad1 = tl.reshape(scale_grad1, (BLOCK_M, BLOCK_SN)) + grad1 = tl.reshape(grad1, (BLOCK_M, BLOCK_N)) + + if STOCHASTIC: + # noise_block_ptr = tl.make_block_ptr( + # base=noise_ptr, + # shape=(M, N), + # strides=(input1_stride_0, input1_stride_1), + # offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + # block_shape=(BLOCK_M, BLOCK_N), + # order=(1, 0) + # ) + # noise = tl.load(noise_block_ptr) + + offs_m = pid_dim0 * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_dim1 * BLOCK_N + tl.arange(0, BLOCK_N) + noise_offset = offs_m[:, None] * input1_stride_0 + offs_n[None, :] * input1_stride_1 + noise = tl.rand(0, noise_offset) + + grad1 = _stochastic_rounding(grad1, noise, e_bit, m_bit) + + grad1 = grad1.to(output1_ptr.type.element_ty) + + # pointers + output1_block_ptr = tl.make_block_ptr( + base=output1_ptr, + shape=(M, N), + strides=(output1_stride_0, output1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output1_ptr = tl.make_block_ptr( + base=output1_scale_ptr, + shape=(M, SN), + strides=(s_output1_stride_0, s_output1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + tl.store(output1_block_ptr, grad1, boundary_check=(0, 1)) + tl.store(scale_output1_ptr, scale_grad1, boundary_check=(0, 1)) + + # Actual Calculation of Mul Backward + grad2 = grad * input1 + # Quantize the grad 1 - Scale calculation + abs_grad2 = tl.abs(grad2) + max_val = tl.max(abs_grad2, axis=2) + SCALE_MIN_THRES + scale_grad2 = max_val / fp8_max + scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN, 1)) + # Quantize + # grad1 = tl.fdiv(grad1, scale_output) # do not quantize the output due to the data flow + grad2 = grad2.to(output2_ptr.type.element_ty) + scale_grad2 = scale_grad2.to(output2_scale_ptr.type.element_ty) + scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN)) + grad2 = tl.reshape(grad2, (BLOCK_M, BLOCK_N)) + + # pointers + output2_block_ptr = tl.make_block_ptr( + base=output2_ptr, + shape=(M, N), + strides=(output2_stride_0, output2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output2_ptr = tl.make_block_ptr( + base=output2_scale_ptr, + shape=(M, SN), + strides=(s_output2_stride_0, s_output2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + tl.store(output2_block_ptr, grad2, boundary_check=(0, 1)) + tl.store(scale_output2_ptr, scale_grad2, 
boundary_check=(0, 1)) + + +def fp8_mul_backward_legacy( + x1, s_x1, x2, s_x2, g, s_g, QB, stochastic=False +): # Stochastic Rounding is left outside this function + # Change batched 3D input to 2D + batched = False + if len(x1.shape) == 3: + assert len(s_x1.shape) == 3 + batched = True + BS = x1.shape[0] + x1 = x1.reshape(-1, x1.shape[-1]) + s_x1 = s_x1.reshape(-1, s_x1.shape[-1]) + x2 = x2.reshape(-1, x2.shape[-1]) + s_x2 = s_x2.reshape(-1, s_x2.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + s_g = s_g.reshape(-1, s_g.shape[-1]) + + if stochastic: + noise = torch.empty_like(g, dtype=torch.float32).uniform_(-0.5, 0.5) + else: + noise = None + + # defining the input and output tensor + M, N = x1.shape + _, SN = s_x1.shape # assume the shape of quantization block size is always 1 * G + assert x1.shape == x2.shape + assert s_x1.shape == s_x2.shape + + y1 = torch.empty_like(g, dtype=g.dtype) + s_y1 = torch.empty_like(s_g, dtype=s_g.dtype) + y2 = torch.empty_like(g, dtype=torch.bfloat16) + s_y2 = torch.empty_like(s_g, dtype=s_g.dtype) + fp8MaxValue = FP8_MAX_VALUE[g.dtype] # E4M3 and E5M2 have different max value + e_bit, m_bit = convert_fp8_to_embit[g.dtype] + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_mul_backward_legacy_kernel[grid]( + y1, + s_y1, + y2, + s_y2, + x1, + s_x1, + x2, + s_x2, + g, + s_g, + noise, + M, + N, + SN, + QB, + fp8MaxValue, + e_bit, + m_bit, + x1.stride(0), + x1.stride(1), + s_x1.stride(0), + s_x1.stride(1), + x2.stride(0), + x2.stride(1), + s_x2.stride(0), + s_x2.stride(1), + g.stride(0), + g.stride(1), + s_g.stride(0), + s_g.stride(1), + y1.stride(0), + y1.stride(1), + s_y1.stride(0), + s_y1.stride(1), + y2.stride(0), + y2.stride(1), + s_y2.stride(0), + s_y2.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + STOCHASTIC=stochastic, + ) + + # Recover 2D to 3D + if batched: + y1 = y1.reshape(BS, -1, y1.shape[-1]) + s_y1 = s_y1.reshape(BS, -1, s_y1.shape[-1]) + + y2 = y2.reshape(BS, -1, y2.shape[-1]) + s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1]) + + return y1, s_y1, y2, s_y2 diff --git a/llava/model/coat/activation/real_quantization/mul_bwd_silu_fwd.py b/llava/model/coat/activation/real_quantization/mul_bwd_silu_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..d23c26f6c86284b8d614ae0f90819bc2e4bc1718 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/mul_bwd_silu_fwd.py @@ -0,0 +1,337 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
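+#
+# Usage sketch (illustrative; the tensor names are placeholders): this fused kernel recomputes
+# SiLU(gate) from the quantized gate input while running the element-wise-mul backward, e.g.
+#
+#   grad_gate, (grad_up, s_grad_up) = fp8_mul_backward_silu_forward(
+#       x_gate, s_gate, x_up, s_up, g, QB=16, fp8type=torch.float8_e4m3fn,
+#   )
+#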
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division_transpose import fp8_division_transpose +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block + +"""Element-wise Multiplication Backward""" +"""Input1 (Gate) uses 1 * 16 group quantization""" +"""Input2 (Up) uses 1 * 16 group quantization""" +"""Grad (Down) uses full-precision/BF16""" +"""Output1 (Gate) uses full-precision/BF16""" +"""Output2 (Up) uses per-tensor quantization, we can choose whether it should be quantized inside this function""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_mul_backward_silu_forward_kernel( + output1_ptr, # output + output2_ptr, + output2_scale_ptr, # output + input1_ptr, + input1_scale_ptr, # input + input2_ptr, + input2_scale_ptr, # input + grad_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, + e_bit, + m_bit, # shape + input1_stride_0, + input1_stride_1, # input1 stride + s_input1_stride_0, + s_input1_stride_1, # scale of input1 stride + input2_stride_0, + input2_stride_1, # input2 stride + s_input2_stride_0, + s_input2_stride_1, # scale of input2 stride + grad_stride_0, + grad_stride_1, # input stride + output1_stride_0, + output1_stride_1, # output stride + output2_stride_0, + output2_stride_1, # output stride + s_output2_stride_0, + s_output2_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + STOCHASTIC: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # --- The first input --- + input1_block_ptr = tl.make_block_ptr( + base=input1_ptr, + shape=(M, N), + strides=(input1_stride_0, input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input1_ptr = tl.make_block_ptr( + base=input1_scale_ptr, + shape=(M, SN), + strides=(s_input1_stride_0, s_input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input1 = tl.load(input1_block_ptr) + scale_input1 = tl.load(scale_input1_ptr) + + input1 = input1.to(tl.float32) + scale_input1 = scale_input1.to(tl.float32) + + # Dequantize and mul calculation + scale_input1 = tl.reshape(scale_input1, (BLOCK_M, BLOCK_SN, 1)) + input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB)) + input1 = input1 * scale_input1 + + # --- recompute SiLU --- + # Actual Calculation of SiLU + sigmoid = 1 / (1.0 + libdevice.exp(-input1)) + silu_output = sigmoid * input1 + + # Quantize Scale calculation + abs_output = tl.abs(silu_output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize + silu_output = tl.fdiv(silu_output, scale_output) + silu_output = silu_output.to(input1_ptr.type.element_ty) + + input1 = silu_output * scale_output + + # --- The second input --- + input2_block_ptr = tl.make_block_ptr( + base=input2_ptr, + shape=(M, N), + strides=(input2_stride_0, 
input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input2_ptr = tl.make_block_ptr( + base=input2_scale_ptr, + shape=(M, SN), + strides=(s_input2_stride_0, s_input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input2 = tl.load(input2_block_ptr) + scale_input2 = tl.load(scale_input2_ptr) + + input2 = input2.to(tl.float32) + scale_input2 = scale_input2.to(tl.float32) + + # Dequantize and mul calculation + scale_input2 = tl.reshape(scale_input2, (BLOCK_M, BLOCK_SN, 1)) + input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB)) + input2 = input2 * scale_input2 + + # pointers of gradient + grad_block_ptr = tl.make_block_ptr( + base=grad_ptr, + shape=(M, N), + strides=(grad_stride_0, grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + grad = tl.load(grad_block_ptr) + grad = grad.to(tl.float32) + grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB)) + + # Actual Calculation of Mul Backward + grad1 = grad * input2 + grad1 = tl.reshape(grad1, (BLOCK_M, BLOCK_N)) + grad1 = grad1.to(output1_ptr.type.element_ty) + + # pointers + output1_block_ptr = tl.make_block_ptr( + base=output1_ptr, + shape=(M, N), + strides=(output1_stride_0, output1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + tl.store(output1_block_ptr, grad1, boundary_check=(0, 1)) + + # Actual Calculation of Mul Backward + grad2 = grad * input1 + # Quantize the grad 1 - Scale calculation + abs_grad2 = tl.abs(grad2) + max_val = tl.max(abs_grad2, axis=2) + SCALE_MIN_THRES + scale_grad2 = max_val / fp8_max + scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN, 1)) + # Quantize + # grad1 = tl.fdiv(grad1, scale_output) # do not quantize the output due to the data flow + grad2 = grad2.to(output2_ptr.type.element_ty) + scale_grad2 = scale_grad2.to(output2_scale_ptr.type.element_ty) + scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN)) + grad2 = tl.reshape(grad2, (BLOCK_M, BLOCK_N)) + + # pointers + output2_block_ptr = tl.make_block_ptr( + base=output2_ptr, + shape=(M, N), + strides=(output2_stride_0, output2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output2_ptr = tl.make_block_ptr( + base=output2_scale_ptr, + shape=(M, SN), + strides=(s_output2_stride_0, s_output2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + tl.store(output2_block_ptr, grad2, boundary_check=(0, 1)) + tl.store(scale_output2_ptr, scale_grad2, boundary_check=(0, 1)) + + +def fp8_mul_backward_silu_forward( + x1, s_x1, x2, s_x2, g, QB, fp8type, stochastic=False, output_quantized_transpose=False +): # Stochastic Rounding is left outside this function + # Change batched 3D input to 2D + batched = False + if len(x1.shape) == 3: + assert len(s_x1.shape) == 3 + batched = True + BS = x1.shape[0] + x1 = x1.reshape(-1, x1.shape[-1]) + s_x1 = s_x1.reshape(-1, s_x1.shape[-1]) + x2 = x2.reshape(-1, x2.shape[-1]) + s_x2 = s_x2.reshape(-1, s_x2.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + + if stochastic: + noise = torch.empty_like(g, dtype=torch.float32).uniform_(-0.5, 0.5) + else: + noise = None + + # defining the input and output tensor + M, N = x1.shape + _, SN = s_x1.shape # assume the shape of quantization block size 
is always 1 * G + assert x1.shape == x2.shape + assert s_x1.shape == s_x2.shape + + if isinstance(fp8type, str): + fp8type = convert_str_to_fp8[fp8type] + y1 = torch.empty_like(g, dtype=torch.bfloat16) + y2 = torch.empty_like(g, dtype=torch.bfloat16) + s_y2 = torch.empty_like(s_x1, dtype=torch.bfloat16) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + e_bit, m_bit = convert_fp8_to_embit[fp8type] + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_mul_backward_silu_forward_kernel[grid]( + y1, + y2, + s_y2, + x1, + s_x1, + x2, + s_x2, + g, + M, + N, + SN, + QB, + fp8MaxValue, + e_bit, + m_bit, + x1.stride(0), + x1.stride(1), + s_x1.stride(0), + s_x1.stride(1), + x2.stride(0), + x2.stride(1), + s_x2.stride(0), + s_x2.stride(1), + g.stride(0), + g.stride(1), + y1.stride(0), + y1.stride(1), + y2.stride(0), + y2.stride(1), + s_y2.stride(0), + s_y2.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + STOCHASTIC=stochastic, + ) + + if not output_quantized_transpose: + # Recover 2D to 3D + if batched: + y1 = y1.reshape(BS, -1, y1.shape[-1]) + + y2 = y2.reshape(BS, -1, y2.shape[-1]) + s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1]) + + return y1, (y2, s_y2) + else: + # Per-tensor quantization + s_y2_max = s_y2.max() + qy2, s_y2_max, qy2_t = fp8_division_transpose(y2, QB, fp8type, s_y2_max) + + # Recover 2D to 3D + if batched: + y1 = y1.reshape(BS, -1, y1.shape[-1]) + + y2 = y2.reshape(BS, -1, y2.shape[-1]) + qy2 = qy2.reshape(BS, -1, qy2.shape[-1]) + s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1]) + + return y1, (qy2, s_y2_max, qy2_t) diff --git a/llava/model/coat/activation/real_quantization/mul_fwd.py b/llava/model/coat/activation/real_quantization/mul_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..0037c7220c6f1bb0b7b9b5946a7d2553b44c7426 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/mul_fwd.py @@ -0,0 +1,260 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
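+#
+# Usage sketch (illustrative; the tensor names are placeholders): both inputs are 1 * 16
+# group-quantized, and the product is returned per-tensor quantized along with its transpose,
+#
+#   q_out, s_out_max, q_out_t = fp8_mul_forward(x_gate, s_gate, x_up, s_up, QB=16)
+#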
+# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division_transpose import fp8_division_transpose +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block + +"""Element-wise Multiplication Forward""" +"""Input1 (Gate) uses 1 * 16 group quantization""" +"""Input2 (Up) uses 1 * 16 group quantization""" +"""Output uses per-tensor quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + +fp8_max_value = { + torch.float8_e4m3fn: 448, + torch.float8_e5m2: 57344, +} + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def fp8_mul_forward_kernel( + output_ptr, + output_scale_ptr, # output + input1_ptr, + input1_scale_ptr, # input + input2_ptr, + input2_scale_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input1_stride_0, + input1_stride_1, # input1 stride + s_input1_stride_0, + s_input1_stride_1, # scale of input1 stride + input2_stride_0, + input2_stride_1, # input2 stride + s_input2_stride_0, + s_input2_stride_1, # scale of input2 stride + output_stride_0, + output_stride_1, # output stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # --- The first input --- + input1_block_ptr = tl.make_block_ptr( + base=input1_ptr, + shape=(M, N), + strides=(input1_stride_0, input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input1_ptr = tl.make_block_ptr( + base=input1_scale_ptr, + shape=(M, SN), + strides=(s_input1_stride_0, s_input1_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input1 = tl.load(input1_block_ptr) + scale_input1 = tl.load(scale_input1_ptr) + + input1 = input1.to(tl.float32) + scale_input1 = scale_input1.to(tl.float32) + + # Dequantize and mul calculation + scale_input1 = tl.reshape(scale_input1, (BLOCK_M, BLOCK_SN, 1)) + input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB)) + input1 = input1 * scale_input1 + + # --- The second input --- + input2_block_ptr = tl.make_block_ptr( + base=input2_ptr, + shape=(M, N), + strides=(input2_stride_0, input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input2_ptr = tl.make_block_ptr( + base=input2_scale_ptr, + shape=(M, SN), + strides=(s_input2_stride_0, s_input2_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input2 = tl.load(input2_block_ptr) + scale_input2 = tl.load(scale_input2_ptr) + + input2 = input2.to(tl.float32) + scale_input2 = scale_input2.to(tl.float32) + + # Dequantize and mul calculation + scale_input2 = tl.reshape(scale_input2, (BLOCK_M, BLOCK_SN, 1)) + input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB)) + input2 = input2 * scale_input2 + + # Actual Calculation of SiLU + mul_output = input1 * input2 + + # Quantize Scale calculation + abs_output = tl.abs(mul_output) + max_val = 
tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize + # mul_output = tl.fdiv(mul_output, scale_output) # do not quantize the output since it should use per-tensor quantization afterwards + mul_output = mul_output.to(output_ptr.type.element_ty) + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + mul_output = tl.reshape(mul_output, (BLOCK_M, BLOCK_N)) + + # debug + # mul_output = input + # scale_output = scale_input + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(output_block_ptr, mul_output, boundary_check=(0, 1)) + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_mul_forward(x1, s_x1, x2, s_x2, QB, transpose_output_2d=False): + # Change batched 3D input to 2D + batched = False + if len(x1.shape) == 3: + assert len(s_x1.shape) == 3 + batched = True + BS = x1.shape[0] + x1 = x1.reshape(-1, x1.shape[-1]) + s_x1 = s_x1.reshape(-1, s_x1.shape[-1]) + x2 = x2.reshape(-1, x2.shape[-1]) + s_x2 = s_x2.reshape(-1, s_x2.shape[-1]) + + # defining the input and output tensor + M, N = x1.shape + _, SN = s_x1.shape # assume the shape of quantization block size is always 1 * G + assert x1.shape == x2.shape + assert s_x1.shape == s_x2.shape + + y = torch.empty_like(x1, dtype=torch.bfloat16) + s_y = torch.empty_like(s_x1, dtype=s_x1.dtype) + fp8MaxValue = fp8_max_value[x1.dtype] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + fp8_mul_forward_kernel[grid]( + y, + s_y, + x1, + s_x1, + x2, + s_x2, + M, + N, + SN, + QB, + fp8MaxValue, + x1.stride(0), + x1.stride(1), + s_x1.stride(0), + s_x1.stride(1), + x2.stride(0), + x2.stride(1), + s_x2.stride(0), + s_x2.stride(1), + y.stride(0), + y.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + s_y_max = s_y.max() + qy, s_y_max, qy_t = fp8_division_transpose(y, QB, x2.dtype, s_y_max) + qy = qy.to(x2.dtype) + qy_t = qy_t.to(x2.dtype) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + qy = qy.reshape(BS, -1, qy.shape[-1]) + if not transpose_output_2d: + qy_t = qy_t.reshape(BS, -1, qy.shape[-1]) + + return qy, s_y_max, qy_t diff --git a/llava/model/coat/activation/real_quantization/silu_bwd.py b/llava/model/coat/activation/real_quantization/silu_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..052bf66e6af6f5f3ce40e08e7b5c8c92e97bfbe3 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/silu_bwd.py @@ -0,0 +1,255 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from ._division_transpose import fp8_division_transpose +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block + +"""SiLU Activation Backward""" +"""Input uses 1 * 16 group quantization""" +"""Grad uses full-precision / BF16""" +"""Output uses per-tensor quantization, we can choose whether it should be quantized inside this function""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_silu_backward_kernel( + output_ptr, + output_scale_ptr, # output + input_ptr, + input_scale_ptr, # input + grad_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input_stride_0, + input_stride_1, # input stride + s_input_stride_0, + s_input_stride_1, # scale of input stride + grad_stride_0, + grad_stride_1, # input stride + output_stride_0, + output_stride_1, # output stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input_ptr = tl.make_block_ptr( + base=input_scale_ptr, + shape=(M, SN), + strides=(s_input_stride_0, s_input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + scale_input = tl.load(scale_input_ptr) + + input = input.to(tl.float32) + scale_input = scale_input.to(tl.float32) + + # Dequantize and silu calculation + scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1)) + input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + input = input * scale_input + + # pointers of gradient + grad_block_ptr = tl.make_block_ptr( + base=grad_ptr, + shape=(M, N), + strides=(grad_stride_0, grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + grad = tl.load(grad_block_ptr) + grad = grad.to(tl.float32) + grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB)) + + # Actual Calculation of SiLU's backward + sigmoid = 1 / (1.0 + libdevice.exp(-input)) + silu_output = sigmoid + input * sigmoid * (1 - sigmoid) + silu_output = silu_output * grad + + # Quantize Scale calculation + abs_output = tl.abs(silu_output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, 
BLOCK_SN, 1)) + + # Quantize + # silu_output = tl.fdiv(silu_output, scale_output) + silu_output = silu_output.to(output_ptr.type.element_ty) + + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + silu_output = tl.reshape(silu_output, (BLOCK_M, BLOCK_N)) + + # debug + # silu_output = input + # scale_output = scale_input + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(output_block_ptr, silu_output, boundary_check=(0, 1)) + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_silu_backward( + x, s_x, g, QB, fp8type, stochastic=False, output_quantized_transpose=False +): # Stochastic Rounding is left outside this function + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G + + if isinstance(fp8type, str): + fp8type = convert_str_to_fp8[fp8type] + + y = torch.empty_like(g, dtype=torch.bfloat16) + s_y = torch.empty_like(s_x, dtype=torch.bfloat16) + fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_silu_backward_kernel[grid]( + y, + s_y, + x, + s_x, + g, + M, + N, + SN, + QB, + fp8MaxValue, + x.stride(0), + x.stride(1), + s_x.stride(0), + s_x.stride(1), + g.stride(0), + g.stride(1), + y.stride(0), + y.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + if not output_quantized_transpose: + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + + return y, s_y + else: + # Per-tensor quantization + s_y_max = s_y.max() + qy, s_y_max, qy_t = fp8_division_transpose(y, QB, fp8type, s_y_max) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + qy = qy.reshape(BS, -1, qy.shape[-1]) + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + + return qy, s_y_max, qy_t + + # # Per-tensor quantization + # s_y_max = s_y.max() + # qy, s_y_max = fp8_division(y, QB, fp8type, s_y_max) + + # # Recover 2D to 3D + # if batched: + # y = y.reshape(BS, -1, y.shape[-1]) + # qy = qy.reshape(BS, -1, qy.shape[-1]) + # s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + + # return qy, s_y_max diff --git a/llava/model/coat/activation/real_quantization/silu_bwd_legacy.py b/llava/model/coat/activation/real_quantization/silu_bwd_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..88de423ad973f34718f8d44e1d2e37b3113e1737 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/silu_bwd_legacy.py @@ -0,0 +1,248 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. 
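+#
+# Usage sketch (illustrative; the tensor names are placeholders): unlike fp8_silu_backward,
+# the incoming gradient here is itself 1 * 16 group-quantized (g with scales s_g),
+#
+#   dy, s_dy = fp8_silu_backward_legacy(x, s_x, g, s_g, QB=16)
+#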
+# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block + +"""SiLU Activation Backward""" +"""Input uses 1 * 16 group quantization""" +"""Grad uses 1 * 16 group quantization""" +"""Output uses per-tensor quantization, but should be quantized outside this function""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_silu_backward_legacy_kernel( + output_ptr, + output_scale_ptr, # output + input_ptr, + input_scale_ptr, # input + grad_ptr, + grad_scale_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input_stride_0, + input_stride_1, # input stride + s_input_stride_0, + s_input_stride_1, # scale of input stride + grad_stride_0, + grad_stride_1, # input stride + s_grad_stride_0, + s_grad_stride_1, # scale of input stride + output_stride_0, + output_stride_1, # output stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input_ptr = tl.make_block_ptr( + base=input_scale_ptr, + shape=(M, SN), + strides=(s_input_stride_0, s_input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + scale_input = tl.load(scale_input_ptr) + + input = input.to(tl.float32) + scale_input = scale_input.to(tl.float32) + + # Dequantize and silu calculation + scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1)) + input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + input = input * scale_input + + # pointers of gradient + grad_block_ptr = tl.make_block_ptr( + base=grad_ptr, + shape=(M, N), + strides=(grad_stride_0, grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # grad ptr + scale_grad_ptr = tl.make_block_ptr( + base=grad_scale_ptr, + shape=(M, SN), + strides=(s_grad_stride_0, s_grad_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + grad = tl.load(grad_block_ptr) + 
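+    # the legacy path also carries per-group scales for the gradient; load them so the FP8
+    # grad can be dequantized to fp32 before the SiLU backward below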
scale_grad = tl.load(scale_grad_ptr) + + grad = grad.to(tl.float32) + scale_grad = scale_grad.to(tl.float32) + + # Dequantize and silu calculation + scale_grad = tl.reshape(scale_grad, (BLOCK_M, BLOCK_SN, 1)) + grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB)) + grad = grad * scale_grad + + # Actual Calculation of SiLU's backward + sigmoid = 1 / (1.0 + libdevice.exp(-input)) + silu_output = sigmoid + input * sigmoid * (1 - sigmoid) + silu_output = silu_output * grad + + # Quantize Scale calculation + abs_output = tl.abs(silu_output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize + # silu_output = tl.fdiv(silu_output, scale_output) + silu_output = silu_output.to(output_ptr.type.element_ty) + + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + silu_output = tl.reshape(silu_output, (BLOCK_M, BLOCK_N)) + + # debug + # silu_output = input + # scale_output = scale_input + + # pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(output_block_ptr, silu_output, boundary_check=(0, 1)) + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_silu_backward_legacy(x, s_x, g, s_g, QB, stochastic=False): # Stochastic Rounding is left outside this function + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + s_g = s_g.reshape(-1, s_g.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G + + y = torch.empty_like(g, dtype=torch.bfloat16) + s_y = torch.empty_like(s_g, dtype=s_g.dtype) + fp8MaxValue = FP8_MAX_VALUE[g.dtype] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_silu_backward_legacy_kernel[grid]( + y, + s_y, + x, + s_x, + g, + s_g, + M, + N, + SN, + QB, + fp8MaxValue, + x.stride(0), + x.stride(1), + s_x.stride(0), + s_x.stride(1), + g.stride(0), + g.stride(1), + s_g.stride(0), + s_g.stride(1), + y.stride(0), + y.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + + return y, s_y diff --git a/llava/model/coat/activation/real_quantization/silu_fwd.py b/llava/model/coat/activation/real_quantization/silu_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..1e6dadc42b2e25ad548ec3e2f30cc146b8ce01d0 --- /dev/null +++ b/llava/model/coat/activation/real_quantization/silu_fwd.py @@ -0,0 +1,204 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
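+#
+# Usage sketch (illustrative; the tensor names are placeholders): input and output are both
+# 1 * 16 group-quantized FP8 tensors with per-group scales,
+#
+#   q_silu, s_silu = fp8_silu_forward(qx, s_x, QB=16)
+#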
+ +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# 4 block +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice + +try: + from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block +except: + from common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block + +"""SiLU Activation Forward""" +"""Input uses 1 * 16 group quantization""" +"""Output uses 1 * 16 group quantization""" +"""The input can be 2D or 3D, but the calculation is performed in 2D""" + + +@triton.autotune( + configs=[] + get_configs_io_block(), + key=[ + "N", + ], +) +@triton.heuristics( + { + "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"], + } +) +@triton.jit +def _fp8_silu_forward_kernel( + output_ptr, + output_scale_ptr, # output + input_ptr, + input_scale_ptr, # input + M, + N, + SN, + QB: tl.constexpr, + fp8_max, # shape + input_stride_0, + input_stride_1, # input stride + s_input_stride_0, + s_input_stride_1, # scale of input stride + output_stride_0, + output_stride_1, # output stride + s_output_stride_0, + s_output_stride_1, # scale of output stride + SCALE_MIN_THRES: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_SN: tl.constexpr, +): # CUDA block size + + # Block PID + pid = tl.program_id(0) + NUM_BLOCK_N = tl.cdiv(N, BLOCK_N) + pid_dim0 = pid // NUM_BLOCK_N + pid_dim1 = pid % NUM_BLOCK_N + + # pointers + input_block_ptr = tl.make_block_ptr( + base=input_ptr, + shape=(M, N), + strides=(input_stride_0, input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + + # input ptr + scale_input_ptr = tl.make_block_ptr( + base=input_scale_ptr, + shape=(M, SN), + strides=(s_input_stride_0, s_input_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + input = tl.load(input_block_ptr) + scale_input = tl.load(scale_input_ptr) + + input = input.to(tl.float32) + scale_input = scale_input.to(tl.float32) + + # Dequantize and silu calculation + scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1)) + input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB)) + input = input * scale_input + + # Actual Calculation of SiLU + sigmoid = 1 / (1.0 + libdevice.exp(-input)) + silu_output = sigmoid * input + + # Quantize Scale calculation + abs_output = tl.abs(silu_output) + max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES + scale_output = max_val / fp8_max + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1)) + + # Quantize + silu_output = tl.fdiv(silu_output, scale_output) + silu_output = silu_output.to(output_ptr.type.element_ty) + + scale_output = scale_output.to(output_scale_ptr.type.element_ty) + scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN)) + silu_output = tl.reshape(silu_output, (BLOCK_M, BLOCK_N)) + + # debug + # silu_output = input + # scale_output = scale_input + + # 
pointers + output_block_ptr = tl.make_block_ptr( + base=output_ptr, + shape=(M, N), + strides=(output_stride_0, output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + scale_output_ptr = tl.make_block_ptr( + base=output_scale_ptr, + shape=(M, SN), + strides=(s_output_stride_0, s_output_stride_1), + offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN), + block_shape=(BLOCK_M, BLOCK_SN), + order=(1, 0), + ) + + tl.store(output_block_ptr, silu_output, boundary_check=(0, 1)) + tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1)) + + +def fp8_silu_forward(x, s_x, QB): + # Change batched 3D input to 2D + batched = False + if len(x.shape) == 3: + assert len(s_x.shape) == 3 + batched = True + BS = x.shape[0] + x = x.reshape(-1, x.shape[-1]) + s_x = s_x.reshape(-1, s_x.shape[-1]) + + # defining the input and output tensor + M, N = x.shape + _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G + + y = torch.empty_like(x, dtype=x.dtype) + s_y = torch.empty_like(s_x, dtype=s_x.dtype) + fp8MaxValue = FP8_MAX_VALUE[x.dtype] # E4M3 and E5M2 have different max value + + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) + + _fp8_silu_forward_kernel[grid]( + y, + s_y, + x, + s_x, + M, + N, + SN, + QB, + fp8MaxValue, + x.stride(0), + x.stride(1), + s_x.stride(0), + s_x.stride(1), + y.stride(0), + y.stride(1), + s_y.stride(0), + s_y.stride(1), + SCALE_MIN_THRES=SCALE_MIN_THRES, + ) + + # Recover 2D to 3D + if batched: + y = y.reshape(BS, -1, y.shape[-1]) + s_y = s_y.reshape(BS, -1, s_y.shape[-1]) + + return y, s_y diff --git a/llava/model/coat/activation/utils.py b/llava/model/coat/activation/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b14b79e29cac93f6049c6fa81ff537bd2a5cb8fd --- /dev/null +++ b/llava/model/coat/activation/utils.py @@ -0,0 +1,33 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from collections import defaultdict + +import torch + + +def quant_get_local_rank() -> int: + return int(os.environ.get("LOCAL_RANK") or 0) + + +record_memory_allocated = defaultdict(list) diff --git a/llava/model/coat/fp8_trainer.py b/llava/model/coat/fp8_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..7b9d68029db9fc7fbbd6af01b34d2cf342b4294f --- /dev/null +++ b/llava/model/coat/fp8_trainer.py @@ -0,0 +1,626 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# In this file I only add the logic in line 411 to 415. 
The rest remains unchanged compared with the original Trainer Class. + +import math +import os +import shutil +import time +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from packaging import version +from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler +from transformers import Trainer +from transformers.debug_utils import DebugOption, DebugUnderflowOverflow +from transformers.integrations import hp_params +from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available +from transformers.integrations.tpu import tpu_spmd_dataloader +from transformers.trainer_callback import ExportableState, TrainerState +from transformers.trainer_pt_utils import get_model_param_count +from transformers.trainer_utils import HPSearchBackend, TrainOutput, has_length, speed_metrics +from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments +from transformers.utils import ( + is_accelerate_available, + is_apex_available, + is_bitsandbytes_available, + is_datasets_available, + is_galore_torch_available, + is_grokadamw_available, + is_in_notebook, + is_ipex_available, + is_liger_kernel_available, + is_lomo_available, + is_peft_available, + is_safetensors_available, + is_sagemaker_dp_enabled, + is_sagemaker_mp_enabled, + is_schedulefree_available, + is_torch_compile_available, + is_torch_mlu_available, + is_torch_mps_available, + is_torch_musa_available, + is_torch_neuroncore_available, + is_torch_npu_available, + is_torch_xla_available, + is_torch_xpu_available, + is_torchao_available, + logging, +) + +if is_apex_available(): + from apex import amp + +if is_datasets_available(): + import datasets + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + import torch_xla.debug.metrics as met + from torch_xla import __version__ as XLA_VERSION + + IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION) + if IS_XLA_FSDPV2_POST_2_2: + import torch_xla.distributed.spmd as xs + import torch_xla.runtime as xr +else: + IS_XLA_FSDPV2_POST_2_2 = False + + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + from smdistributed.modelparallel import __version__ as SMP_VERSION + + IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") + + from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat +else: + IS_SAGEMAKER_MP_POST_1_10 = False + + +if is_safetensors_available(): + import safetensors.torch + +if is_peft_available(): + from peft import PeftModel + + +if is_accelerate_available(): + from accelerate import Accelerator + from accelerate import __version__ as accelerate_version + from accelerate import skip_first_batches + from accelerate.utils import ( + DistributedDataParallelKwargs, + DistributedType, + GradientAccumulationPlugin, + load_fsdp_model, + load_fsdp_optimizer, + save_fsdp_model, + save_fsdp_optimizer, + ) + + DATA_SAMPLERS = [RandomSampler] + if version.parse(accelerate_version) > version.parse("0.23.0"): + from accelerate.data_loader import SeedableRandomSampler + + DATA_SAMPLERS += [SeedableRandomSampler] + + if is_deepspeed_available(): + from accelerate.utils import DeepSpeedSchedulerWrapper + +if is_accelerate_available("0.28.0"): + from accelerate.utils import DataLoaderConfiguration + +from .activation.models._fp8manager import FP8Manager + +logger = 
logging.get_logger(__name__) + + +class CoatFP8Trainer(Trainer): + def _inner_training_loop( + self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None + ): + self.accelerator.free_memory() + self._train_batch_size = batch_size + if self.args.auto_find_batch_size: + if self.state.train_batch_size != self._train_batch_size: + from accelerate.utils import release_memory + + (self.model_wrapped,) = release_memory(self.model_wrapped) + self.model_wrapped = self.model + + # Check for DeepSpeed *after* the intial pass and modify the config + if self.is_deepspeed_enabled: + # Temporarily unset `self.args.train_batch_size` + original_bs = self.args.per_device_train_batch_size + self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu) + self.propagate_args_to_deepspeed(True) + self.args.per_device_train_batch_size = original_bs + self.state.train_batch_size = self._train_batch_size + logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") + # Data loader and number of training steps + train_dataloader = self.get_train_dataloader() + if self.is_fsdp_xla_v2_enabled: + train_dataloader = tpu_spmd_dataloader(train_dataloader) + + # Setting up training control variables: + # number of training epochs: num_train_epochs + # number of training steps per epoch: num_update_steps_per_epoch + # total number of training steps to execute: max_steps + total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size + + len_dataloader = None + num_train_tokens = None + if has_length(train_dataloader): + len_dataloader = len(train_dataloader) + num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + num_examples = self.num_examples(train_dataloader) + if args.max_steps > 0: + max_steps = args.max_steps + num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( + args.max_steps % num_update_steps_per_epoch > 0 + ) + # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's + # the best we can do. + num_train_samples = args.max_steps * total_train_batch_size + if args.include_tokens_per_second: + num_train_tokens = ( + self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps + ) + else: + max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) + num_train_epochs = math.ceil(args.num_train_epochs) + num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs + if args.include_tokens_per_second: + num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs + elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size + max_steps = args.max_steps + # Setting a very large number of epochs so we go as many times as necessary over the iterator. 
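+                import sys  # `sys` is not among this file's top-level imports above, so import it here before sys.maxsize is used on the next line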
+ num_train_epochs = sys.maxsize + num_update_steps_per_epoch = max_steps + num_examples = total_train_batch_size * args.max_steps + num_train_samples = args.max_steps * total_train_batch_size + if args.include_tokens_per_second: + num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps + else: + raise ValueError( + "args.max_steps must be set to a positive value if dataloader does not have a length, was" + f" {args.max_steps}" + ) + + if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: + if self.args.n_gpu > 1: + # nn.DataParallel(model) replicates the model, creating new variables and module + # references registered here no longer work on other gpus, breaking the module + raise ValueError( + "Currently --debug underflow_overflow is not supported under DP. Please use DDP" + " (torchrun or torch.distributed.launch (deprecated))." + ) + else: + debug_overflow = DebugUnderflowOverflow(self.model) # noqa + + delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled + + # We need to reset the scheduler, as its parameters may be different on subsequent calls + if self._created_lr_scheduler: + self.lr_scheduler = None + self._created_lr_scheduler = False + + if self.is_deepspeed_enabled: + self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) + + if not delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + self.state = TrainerState( + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ] + ) + self.state.is_hyper_param_search = trial is not None + self.state.train_batch_size = self._train_batch_size + + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is not None: + if args.logging_steps < 1: + self.state.logging_steps = math.ceil(max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = math.ceil(max_steps * args.eval_steps) + else: + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + + # Activate gradient checkpointing if needed + if args.gradient_checkpointing: + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs) + + model = self._wrap_model(self.model_wrapped) + + # as the model is wrapped, don't use `accelerator.prepare` + # this is for unhandled cases such as + # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX + use_accelerator_prepare = True if model is self.model else False + + if delay_optimizer_creation: + if use_accelerator_prepare: + self._fsdp_qlora_plugin_updates() + self.model = self.accelerator.prepare(self.model) + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + # prepare using `accelerator` prepare + if use_accelerator_prepare: + self.model.train() + if hasattr(self.lr_scheduler, "step"): + if self.use_apex: + model = self.accelerator.prepare(self.model) + else: + model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) + else: + # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. 
+ model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + # In this case we are in DDP + LOMO, which should be supported + self.optimizer = self.accelerator.prepare(self.optimizer) + + if self.is_fsdp_enabled: + self.model = self.model_wrapped = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # ckpt loading + if resume_from_checkpoint is not None: + if self.is_deepspeed_enabled: + deepspeed_load_checkpoint( + self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model) + ) + elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled: + self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped) + + # Check if saved optimizer or scheduler states exist + self._load_optimizer_and_scheduler(resume_from_checkpoint) + + # important: at this point: + # self.model is the Transformers Model + # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), + # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. + + # Train! + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_examples:,}") + logger.info(f" Num Epochs = {num_train_epochs:,}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") + if self.args.per_device_train_batch_size != self._train_batch_size: + logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps:,}") + logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") + + self.state.epoch = 0 + start_time = time.time() + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + steps_trained_progress_bar = None + + # Check if continuing training from a checkpoint + if resume_from_checkpoint is not None and os.path.isfile( + os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) + ): + self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + self.compare_trainer_and_checkpoint_args(self.args, self.state) + self._load_callback_state() + epochs_trained = int(self.state.global_step // num_update_steps_per_epoch) + if not args.ignore_data_skip: + steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) + steps_trained_in_current_epoch *= args.gradient_accumulation_steps + else: + steps_trained_in_current_epoch = 0 + + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(f" Continuing training from epoch {epochs_trained}") + logger.info(f" Continuing training from global step {self.state.global_step}") + if not args.ignore_data_skip: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch." 
+ ) + + # Update the references + self.callback_handler.model = self.model + self.callback_handler.optimizer = self.optimizer + self.callback_handler.lr_scheduler = self.lr_scheduler + self.callback_handler.train_dataloader = train_dataloader + if self.hp_name is not None and self._trial is not None: + # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial + # parameter to Train when using DDP. + self.state.trial_name = self.hp_name(self._trial) + if trial is not None: + assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial + self.state.trial_params = hp_params(assignments) + else: + self.state.trial_params = None + # This should be the same if the state has been saved but in case the training arguments changed, it's safer + # to set this after the load. + self.state.max_steps = max_steps + self.state.num_train_epochs = num_train_epochs + self.state.is_local_process_zero = self.is_local_process_zero() + self.state.is_world_process_zero = self.is_world_process_zero() + + # tr_loss is a tensor to avoid synchronization of TPUs through .item() + tr_loss = torch.tensor(0.0).to(args.device) + # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses + self._total_loss_scalar = 0.0 + self._globalstep_last_logged = self.state.global_step + model.zero_grad() + grad_norm: Optional[float] = None + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + if args.eval_on_start: + self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) + + total_batched_samples = 0 + for epoch in range(epochs_trained, num_train_epochs): + epoch_iterator = train_dataloader + if hasattr(epoch_iterator, "set_epoch"): + epoch_iterator.set_epoch(epoch) + + # Reset the past mems state at the beginning of each epoch if necessary. + if args.past_index >= 0: + self._past = None + + steps_in_epoch = ( + len(epoch_iterator) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps + ) + self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) + + if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + + rng_to_sync = False + steps_skipped = 0 + if steps_trained_in_current_epoch > 0: + epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) + steps_skipped = steps_trained_in_current_epoch + steps_trained_in_current_epoch = 0 + rng_to_sync = True + + step = -1 + for step, inputs in enumerate(epoch_iterator): + # NOTE: FP8 related + if total_batched_samples % args.gradient_accumulation_steps == 0: + FP8Manager.is_first_microbatch = True + else: + FP8Manager.is_first_microbatch = False + + total_batched_samples += 1 + + if self.args.include_num_input_tokens_seen: + main_input_name = getattr(self.model, "main_input_name", "input_ids") + if main_input_name not in inputs: + logger.warning( + "Tried to track the number of tokens seen, however the current model is " + "not configured properly to know what item is the input. To fix this, add " + "a `main_input_name` attribute to the model class you are using." 
+ ) + else: + self.state.num_input_tokens_seen += ( + torch.sum( + self.accelerator.gather( + torch.tensor( + inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64 + ) + ) + ) + .cpu() + .item() + ) + if rng_to_sync: + self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + + # Skip past any already trained steps if resuming training + if steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + if steps_trained_progress_bar is not None: + steps_trained_progress_bar.update(1) + if steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + continue + elif steps_trained_progress_bar is not None: + steps_trained_progress_bar.close() + steps_trained_progress_bar = None + + if step % args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + + with self.accelerator.accumulate(model): + tr_loss_step = self.training_step(model, inputs) + + if ( + args.logging_nan_inf_filter + and not is_torch_xla_available() + and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) + ): + # if loss is nan or inf simply add the average of previous logged losses + tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) + else: + if tr_loss.device != tr_loss_step.device: + raise ValueError( + f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}" + ) + tr_loss += tr_loss_step + + self.current_flos += float(self.floating_point_ops(inputs)) + + is_last_step_and_steps_less_than_grad_acc = ( + steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch + ) + + if ( + total_batched_samples % args.gradient_accumulation_steps == 0 + or + # last step in epoch but step is always smaller than gradient_accumulation_steps + is_last_step_and_steps_less_than_grad_acc + ): + # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered + # in accelerate. So, explicitly enable sync gradients to True in that case. 
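+                    # NOTE: FP8Manager.is_first_microbatch (set at the top of this loop) is True on the
+                    # first micro-batch after each optimizer step; presumably this lets the FP8 layers
+                    # refresh cached quantized weights once per accumulation window (the manager lives in
+                    # llava/model/coat/activation/models/_fp8manager.py, which is not shown in this excerpt).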
+ if is_last_step_and_steps_less_than_grad_acc: + self.accelerator.gradient_state._set_sync_gradients(True) + + # Gradient clipping + if args.max_grad_norm is not None and args.max_grad_norm > 0: + # deepspeed does its own clipping + + if is_sagemaker_mp_enabled() and args.fp16: + _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm) + elif self.use_apex: + # Revert to normal clipping otherwise, handling Apex or full precision + _grad_norm = nn.utils.clip_grad_norm_( + amp.master_params(self.optimizer), + args.max_grad_norm, + ) + else: + _grad_norm = self.accelerator.clip_grad_norm_( + model.parameters(), + args.max_grad_norm, + ) + + if is_accelerate_available() and self.accelerator.distributed_type == DistributedType.DEEPSPEED: + grad_norm = model.get_global_grad_norm() + # In some cases the grad norm may not return a float + if hasattr(grad_norm, "item"): + grad_norm = grad_norm.item() + else: + grad_norm = _grad_norm + + self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) + + self.optimizer.step() + + self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) + + optimizer_was_run = not self.accelerator.optimizer_step_was_skipped + if optimizer_was_run: + # Delay optimizer scheduling until metrics are generated + if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step() + + model.zero_grad() + self.state.global_step += 1 + self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + + self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval) + else: + self.control = self.callback_handler.on_substep_end(args, self.state, self.control) + + if self.control.should_epoch_stop or self.control.should_training_stop: + # PyTorch/XLA relies on the data loader to insert the mark_step for + # each step. Since we are breaking the loop early, we need to manually + # insert the mark_step here. + if is_torch_xla_available(): + xm.mark_step() + break + if step < 0: + logger.warning( + "There seems not to be a single sample in your epoch_iterator, stopping training at step" + f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" + f" num_steps ({max_steps}) higher than the number of available samples." + ) + self.control.should_training_stop = True + + self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) + self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval) + + if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + if is_torch_xla_available(): + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + else: + logger.warning( + "You enabled PyTorch/XLA debug metrics but you don't have a TPU " + "configured. Check your training configuration if this is unexpected." + ) + if self.control.should_training_stop: + break + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of training + delattr(self, "_past") + + logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") + if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + # Wait for everyone to get here so we are sure the model has been saved by process 0. 
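+            import torch.distributed as dist  # `dist` is used in the DISTRIBUTED branch below but does not appear in this file's imports above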
+ if is_torch_xla_available(): + xm.rendezvous("load_best_model_at_end") + elif args.parallel_mode == ParallelMode.DISTRIBUTED: + dist.barrier() + elif is_sagemaker_mp_enabled(): + smp.barrier() + + self._load_best_model() + + # add remaining tr_loss + self._total_loss_scalar += tr_loss.item() + effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError + train_loss = self._total_loss_scalar / effective_global_step + + metrics = speed_metrics( + "train", + start_time, + num_samples=num_train_samples, + num_steps=self.state.max_steps, + num_tokens=num_train_tokens, + ) + self.store_flos() + metrics["total_flos"] = self.state.total_flos + metrics["train_loss"] = train_loss + + self.is_in_train = False + + self._memory_tracker.stop_and_update_metrics(metrics) + + self.log(metrics) + + run_dir = self._get_output_dir(trial) + checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) + + # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. + if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: + for checkpoint in checkpoints_sorted: + if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): + logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") + shutil.rmtree(checkpoint, ignore_errors=True) + + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + + # Wait for the checkpoint to be uploaded. + self._finish_current_push() + + # After training we make sure to retrieve back the original forward pass method + # for the embedding layer by removing the forward post hook. + if self.neftune_noise_alpha is not None: + self._deactivate_neftune(self.model) + + return TrainOutput(self.state.global_step, train_loss, metrics) diff --git a/llava/model/coat/optimizer/fp8_adamw.py b/llava/model/coat/optimizer/fp8_adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..7776ef1539e07773285cf0d220811d6e73070542 --- /dev/null +++ b/llava/model/coat/optimizer/fp8_adamw.py @@ -0,0 +1,515 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +from collections import OrderedDict, defaultdict +from copy import deepcopy +from itertools import chain +from typing import Any, DefaultDict, Dict, Hashable, Iterable, List, Optional, Tuple, Union + +import qoptim_cuda +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer +from typing_extensions import ParamSpec, Self, TypeAlias + +StateDict: TypeAlias = Dict[str, Any] + +convert_str_to_fp8 = {"E4M3": torch.float8_e4m3fn, "E5M2": torch.float8_e5m2} + + +class CoatAdamW(Optimizer): + def __init__( + self, + qargs, + params, + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + amsgrad: bool = False, + *, + fused: Optional[bool] = None, + ): + self.qargs = qargs + assert self.qargs.first_order_expansion == self.qargs.second_order_expansion + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + fused=fused, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + fused = group.setdefault("fused", None) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = torch.tensor(step_val, dtype=torch.float32) + + def _init_group( + self, + group, + params_with_grad, + grads, + amsgrad, + use_expansion, + exp_avgs, + scale_exp_avgs, + expand_exp_avgs, + sqrt_minmax_exp_avgs, + exp_avg_sqs, + scale_exp_avg_sqs, + expand_exp_avg_sqs, + sqrt_minmax_exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ): + for p in group["params"]: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("AdamW does not support sparse gradients") + grads.append(p.grad) + + state = self.state[p] + + # print(f'Param shape: {p.shape}', file=open('debug.txt', 'a')) + # print(f'Param shape: {p.shape}, {p.device}') + + # State initialization + if len(state) == 0: + # This is because kernel launches are costly on CUDA and XLA. 
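+                # State layout created below: both Adam moments are stored in FP8 (dtype chosen by
+                # qargs.first_order_bit / second_order_bit), with one scale entry per qargs.qgroup_size
+                # elements kept in the parameter's dtype; when expansion is enabled, per-group expand
+                # exponents and sqrt(min*max) factors are tracked as well for dynamic-range expansion.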
+ state["step"] = torch.tensor(0.0) + + # Should be torch.float8_e4m3fn + first_order_dtype = convert_str_to_fp8[self.qargs.first_order_bit] + second_order_dtype = convert_str_to_fp8[self.qargs.second_order_bit] + scale_shape = (p.numel() + self.qargs.qgroup_size - 1) // self.qargs.qgroup_size + + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p, dtype=first_order_dtype, memory_format=torch.preserve_format) + state["scale_exp_avg"] = torch.zeros(scale_shape, device=p.device, dtype=p.dtype) + if use_expansion: + state["expand_exp_avg"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype) + state["sqrt_minmax_exp_avg"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p, dtype=second_order_dtype, memory_format=torch.preserve_format) + state["scale_exp_avg_sq"] = torch.zeros(scale_shape, device=p.device, dtype=p.dtype) + if use_expansion: + state["expand_exp_avg_sq"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype) + state["sqrt_minmax_exp_avg_sq"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros(p, memory_format=torch.preserve_format) + + exp_avgs.append(state["exp_avg"]) + scale_exp_avgs.append(state["scale_exp_avg"]) + if use_expansion: + expand_exp_avgs.append(state["expand_exp_avg"]) + sqrt_minmax_exp_avgs.append(state["sqrt_minmax_exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + scale_exp_avg_sqs.append(state["scale_exp_avg_sq"]) + if use_expansion: + expand_exp_avg_sqs.append(state["expand_exp_avg_sq"]) + sqrt_minmax_exp_avg_sqs.append(state["sqrt_minmax_exp_avg_sq"]) + + if group["amsgrad"]: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + state_steps.append(state["step"]) + + @torch._disable_dynamo + def load_state_dict(self, state_dict: StateDict) -> None: + r"""Loads the optimizer state. + + Args: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. 
+ """ + # shallow copy, to be consistent with module API + state_dict = state_dict.copy() + + for pre_hook in self._optimizer_load_state_dict_pre_hooks.values(): + hook_result = pre_hook(self, state_dict) + if hook_result is not None: + state_dict = hook_result + + # Validate the state_dict + groups = self.param_groups + + # Deepcopy as we write into saved_groups later to update state + saved_groups = deepcopy(state_dict["param_groups"]) + + if len(groups) != len(saved_groups): + raise ValueError("loaded state dict has a different number of " "parameter groups") + param_lens = (len(g["params"]) for g in groups) + saved_lens = (len(g["params"]) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError( + "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group" + ) + + # Update the state + id_map = dict( + zip( + chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups) + ) + ) + + def _cast(param, value, param_id=None, param_groups=None, key=None): + r"""Make a deep copy of value, casting all tensors to device of param.""" + if isinstance(value, torch.Tensor): + return CoatAdamW._process_value_according_to_param_policy(param, value, param_id, param_groups, key) + elif isinstance(value, dict): + return { + k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items() + } + elif isinstance(value, Iterable): + return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg] + else: + return value + + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). + state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict) + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"]) + else: + state[k] = v + + # Update parameter groups, setting their 'params' value + def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]: + new_group["params"] = group["params"] + return new_group + + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({"state": state, "param_groups": param_groups}) + + for post_hook in self._optimizer_load_state_dict_post_hooks.values(): + post_hook(self) + + @staticmethod + def _process_value_according_to_param_policy( + param: torch.Tensor, + value: torch.Tensor, + param_id: int, + param_groups: List[Dict[Any, Any]], + key: Hashable = None, + ) -> torch.Tensor: + # Floating-point types are a bit special here. They are the only ones + # that are assumed to always match the type of params. 
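+        # In this FP8 optimizer that dtype-matching assumption is relaxed: quantized moment tensors
+        # must keep their float8 dtypes when a checkpoint is loaded, so below they are only moved to
+        # the parameter's device instead of being cast to the parameter's dtype.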
+ # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424 + # UNLESS fused or capturable, see note [special device hosting for step] + fused = False + capturable = False + assert param_groups is not None + for pg in param_groups: + if param_id in pg["params"]: + fused = pg["fused"] if "fused" in pg else False + capturable = pg["capturable"] if "capturable" in pg else False + break + if key == "step": + if capturable or fused: + return value.to(dtype=torch.float32, device=param.device) + else: + return value + else: + assert value.dtype in [torch.float8_e4m3fn, torch.float8_e5m2, torch.float32] + return value.to(device=param.device) # do not cast optimizer states + # if param.is_floating_point(): + # return value.to(dtype=param.dtype, device=param.device) + # else: + # return value.to(device=param.device) + + @torch.no_grad() + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + scale_exp_avgs = [] + expand_exp_avgs = [] + sqrt_minmax_exp_avgs = [] + exp_avg_sqs = [] + scale_exp_avg_sqs = [] + expand_exp_avg_sqs = [] + sqrt_minmax_exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group["amsgrad"] + use_expansion = self.qargs.first_order_expansion in ["expansion", "true"] + beta1, beta2 = group["betas"] + + self._init_group( + group, + params_with_grad, + grads, + amsgrad, + use_expansion, + exp_avgs, + scale_exp_avgs, + expand_exp_avgs, + sqrt_minmax_exp_avgs, + exp_avg_sqs, + scale_exp_avg_sqs, + expand_exp_avg_sqs, + sqrt_minmax_exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + ) + + Coatadamw( + self.qargs, + params_with_grad, + grads, + exp_avgs, + scale_exp_avgs, + expand_exp_avgs, + sqrt_minmax_exp_avgs, + exp_avg_sqs, + scale_exp_avg_sqs, + expand_exp_avg_sqs, + sqrt_minmax_exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + use_expansion=use_expansion, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + qgroup_size=self.qargs.qgroup_size, + expand_min=self.qargs.expand_min, + fused=group["fused"], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None), + ) + + return loss + + +def Coatadamw( + qargs, + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + scale_exp_avgs: List[Tensor], + expand_exp_avgs: List[Tensor], + sqrt_minmax_exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + scale_exp_avg_sqs: List[Tensor], + expand_exp_avg_sqs: List[Tensor], + sqrt_minmax_exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 + # setting this as kwarg for now as functional API is compiled by torch/distributed/optim + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, + *, + amsgrad: bool, + use_expansion: bool, + beta1: float, + beta2: float, + lr: Union[float, Tensor], + weight_decay: float, + eps: float, + qgroup_size: int, + expand_min: int, +): + r"""Functional API that performs AdamW algorithm computation. + + See :class:`~torch.optim.AdamW` for details. 
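+    Here the moment tensors are FP8 with per-group scales, and the per-parameter update is
+    delegated to the qoptim_cuda extension (fp8_adamw_step / fp8_adamw_expand_step).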
+ """ + if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") + + func = _single_tensor_Coatadamw + + func( + qargs, + params, + grads, + exp_avgs, + scale_exp_avgs, + expand_exp_avgs, + sqrt_minmax_exp_avgs, + exp_avg_sqs, + scale_exp_avg_sqs, + expand_exp_avg_sqs, + sqrt_minmax_exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + use_expansion=use_expansion, + beta1=beta1, + beta2=beta2, + lr=lr, + weight_decay=weight_decay, + eps=eps, + qgroup_size=qgroup_size, + expand_min=expand_min, + grad_scale=grad_scale, + found_inf=found_inf, + ) + + +def _dispatch_sqrt(x: float): # float annotation is needed because of torchscript type inference + if not torch.jit.is_scripting() and isinstance(x, torch.Tensor): + return x.sqrt() + else: + return sqrt(x) + + +def _single_tensor_Coatadamw( + qargs, + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + scale_exp_avgs: List[Tensor], + expand_exp_avgs: List[Tensor], + sqrt_minmax_exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + scale_exp_avg_sqs: List[Tensor], + expand_exp_avg_sqs: List[Tensor], + sqrt_minmax_exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + amsgrad: bool, + use_expansion: bool, + beta1: float, + beta2: float, + lr: Union[Tensor, float], + weight_decay: float, + eps: float, + qgroup_size: int, + expand_min: int, +): + + assert grad_scale is None and found_inf is None + + if torch.jit.is_scripting(): + # this assert is due to JIT being dumb and not realizing that the ops below + # have overloads to handle both float and Tensor lrs, so we just assert it's + # a float since most people using JIT are using floats + assert isinstance(lr, float) + + for i, param in enumerate(params): + grad = grads[i] + # First order + exp_avg = exp_avgs[i] + scale_exp_avg = scale_exp_avgs[i] + # Second order + exp_avg_sq = exp_avg_sqs[i] + scale_exp_avg_sq = scale_exp_avg_sqs[i] + step_t = state_steps[i] + + # print(len(exp_avg.unique()), len(exp_avg_sq.unique())) + # print(f"{param.shape}, {grad.shape}, {exp_avg.shape}, {exp_avg_sq.shape}", file=open('debug.txt', 'a')) + + # update step + step_t += 1 + step = int(step_t.item()) + + # Perform Optimizer Step + if use_expansion: + expand_exp_avg = expand_exp_avgs[i] + sqrt_minmax_exp_avg = sqrt_minmax_exp_avgs[i] + expand_exp_avg_sq = expand_exp_avg_sqs[i] + sqrt_minmax_exp_avg_sq = sqrt_minmax_exp_avg_sqs[i] + + qoptim_cuda.fp8_adamw_expand_step( + param, + grad, + exp_avg, + scale_exp_avg, + expand_exp_avg, + sqrt_minmax_exp_avg, + exp_avg_sq, + scale_exp_avg_sq, + expand_exp_avg_sq, + sqrt_minmax_exp_avg_sq, + beta1, + beta2, + lr, + weight_decay, + eps, + step, + qgroup_size, + expand_min, + ) + + else: + qoptim_cuda.fp8_adamw_step( + param, + grad, + exp_avg, + scale_exp_avg, + exp_avg_sq, + scale_exp_avg_sq, + beta1, + beta2, + lr, + weight_decay, + eps, + step, + qgroup_size, + ) diff --git a/llava/model/coat/optimizer/kernels/bindings.cpp b/llava/model/coat/optimizer/kernels/bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aa3dbf8b145c263000ce3b7151bb73714c55f655 --- /dev/null +++ b/llava/model/coat/optimizer/kernels/bindings.cpp @@ -0,0 +1,10 @@ +#include + +#include "include/fp8_adamw.h" +#include "include/fp8_adamw_expand.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 
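+  // Bindings for the FP8 AdamW CUDA kernels; fp8_adamw.py imports this extension as
+  // `qoptim_cuda` (TORCH_EXTENSION_NAME is presumably set to that name by the build
+  // script, which is not shown in this excerpt).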
+ m.def("fp8_adamw_step", &FP8_AdamW, "Update the quantized optimizer states"); + m.def("fp8_adamw_expand_step", &FP8_AdamW_expand, + "Update the quantized optimizer states, use polynomial expander"); +} diff --git a/llava/model/coat/optimizer/kernels/build/lib.linux-x86_64-cpython-310/qoptim_cuda.cpython-310-x86_64-linux-gnu.so b/llava/model/coat/optimizer/kernels/build/lib.linux-x86_64-cpython-310/qoptim_cuda.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..0d6753b8a2e35f20ec2070c6bce2c9e4acefa0d7 --- /dev/null +++ b/llava/model/coat/optimizer/kernels/build/lib.linux-x86_64-cpython-310/qoptim_cuda.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e40fca6d032a0a3094e50a4fdb04e78998819de2652f0a82e49f49225dd46ba9 +size 235960 diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/bindings.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/bindings.o new file mode 100644 index 0000000000000000000000000000000000000000..be1ed8ce2592d7ce94e249fd52af112a39de1af2 --- /dev/null +++ b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/bindings.o @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:daebebfbb9d1a3dbb5e22acb50af317aa74b437f7aa965bd959b19010c5fd144 +size 244976 diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda.o new file mode 100644 index 0000000000000000000000000000000000000000..837c2b25b4a1759b529b7e5bf2f74a781a60b73c --- /dev/null +++ b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda.o @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ac9c6832be5dccfce3f25ae735c67459afffd12f8f747b5018b3121da3dfd2d7 +size 215440 diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda_kernel.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda_kernel.o new file mode 100644 index 0000000000000000000000000000000000000000..b6c5c480b055b909abffc646f69cd563c615cc47 Binary files /dev/null and b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda_kernel.o differ diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda.o new file mode 100644 index 0000000000000000000000000000000000000000..ec6051ec3846f04b0f2c0f3bbd6abce6a626c3b5 --- /dev/null +++ b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda.o @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:799978e9ba5a24e75b58d0e218cf522874bb983d15fee15adb02d9ac0f2d7a10 +size 216240 diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda_kernel.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda_kernel.o new file mode 100644 index 0000000000000000000000000000000000000000..b29913e078a46679c6aedca63accff5ace9ca5b6 Binary files /dev/null and b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda_kernel.o differ diff --git a/llava/model/coat/optimizer/kernels/csrc_expand_quantize/makefile 
b/llava/model/coat/optimizer/kernels/csrc_expand_quantize/makefile new file mode 100644 index 0000000000000000000000000000000000000000..ad4632ffe641f2043c2cbc4966dff66af2807c26 --- /dev/null +++ b/llava/model/coat/optimizer/kernels/csrc_expand_quantize/makefile @@ -0,0 +1,6 @@ +all: + nvcc nvcc_qoptim.cu -o nvcc_qoptim -gencode=arch=compute_90,code=compute_90 +run: + ./nvcc_qoptim +clean: + rm -f nvcc_qoptim diff --git a/llava/model/coat/optimizer/kernels/csrc_expand_quantize/nvcc_qoptim.cu b/llava/model/coat/optimizer/kernels/csrc_expand_quantize/nvcc_qoptim.cu new file mode 100644 index 0000000000000000000000000000000000000000..79f370e76aaf3b99fc53c7fe053731dc9358729f --- /dev/null +++ b/llava/model/coat/optimizer/kernels/csrc_expand_quantize/nvcc_qoptim.cu @@ -0,0 +1,478 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace cg = cooperative_groups; +#define WARPSIZE 32 +#define QGROUPSIZE 128 +#define QUANT_MIN_VAL 1e-20 + +template +inline float fp8_dtype_max(const T &variable) { + if (std::is_same::value) { + return 448; + } else if (std::is_same::value) { + return 57344; + } else { + throw "Unsupported data format"; + } +} + +typedef enum { fp8_adamw } myCsrcKernels; + +void fp8_adamw_cpu(float *params, float *grads, float *fp_exp_avg, + float *fp_exp_avg_sq, float beta1, float beta2, float lr, + float wd, float eps, int step, int qgroup_size, int M, + int N) { + for (int idx = 0; idx < M * N; idx++) { + fp_exp_avg[idx] = beta1 * fp_exp_avg[idx] + (1 - beta1) * grads[idx]; + fp_exp_avg_sq[idx] = + beta2 * fp_exp_avg_sq[idx] + (1 - beta2) * grads[idx] * grads[idx]; + + const float correction1 = 1.0f - powf(beta1, step); + const float correction2_sqrt = sqrtf(1.0f - powf(beta2, step)); + + float denom = + (sqrtf(fp_exp_avg_sq[idx]) / correction2_sqrt + eps) * correction1; + float update = (fp_exp_avg[idx] / denom) + (wd * params[idx]); + params[idx] = params[idx] - (lr * update); + } +} + +template +void printFloatArrayToFile(T *array, int M, int N, + const std::string &outputFileName) { + std::ofstream outputFile(outputFileName); + if (!outputFile.is_open()) { + std::cout << "Failed to open the file." 
<< std::endl; + return; + } + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + int index = i * N + j; + outputFile << std::setw(10) << std::right << std::fixed + << std::setprecision(6) << (float)array[index] << " "; + if (j == N - 1) { + outputFile << "\n"; + } + } + } +} + +template +__global__ void fp8_adamw_csrc( + scalar_t *__restrict__ params, scalar_t *__restrict__ grads, + __nv_fp8_e4m3 *__restrict__ exp_avg, float *__restrict__ scale_exp_avg, + float *__restrict__ expand_exp_avg, float *__restrict__ sqrtminmax_exp_avg, + __nv_fp8_e4m3 *__restrict__ exp_avg_sq, + float *__restrict__ scale_exp_avg_sq, float *__restrict__ expand_exp_avg_sq, + float *__restrict__ sqrtminmax_exp_avg_sq, float beta1, float beta2, + float lr, float wd, float eps, int step, int qgroup_size, int expand_min, + int total_elements, int total_scale_elements) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int scale_idx = blockIdx.x; + + float float_exp_avg, float_exp_avg_sq; + float correction1, correction2_sqrt; + float denom, update; + + if (idx < total_elements) { + // dequantize the optimizer states + float_exp_avg = float(exp_avg[idx]) * scale_exp_avg[scale_idx]; + int sign_exp_avg = 1 - 2 * signbit(float_exp_avg); + float_exp_avg = sign_exp_avg * + powf(fabsf(float_exp_avg), 1 / expand_exp_avg[scale_idx]) * + sqrtminmax_exp_avg[scale_idx]; + float_exp_avg_sq = float(exp_avg_sq[idx]) * scale_exp_avg_sq[scale_idx]; + float_exp_avg_sq = + powf(float_exp_avg_sq, 1 / expand_exp_avg_sq[scale_idx]) * + sqrtminmax_exp_avg_sq[scale_idx]; + + // calculation of optimizer.step() + float_exp_avg = beta1 * float_exp_avg + (1 - beta1) * grads[idx]; + float_exp_avg_sq = + beta2 * float_exp_avg_sq + (1 - beta2) * grads[idx] * grads[idx]; + + correction1 = 1.0f - powf(beta1, step); + correction2_sqrt = sqrtf(1.0f - powf(beta2, step)); + + denom = (sqrtf(float_exp_avg_sq) / correction2_sqrt + eps) * correction1; + update = (float_exp_avg / denom) + (wd * params[idx]); + params[idx] = params[idx] - (lr * update); + } else { + float_exp_avg = 0.0f; + float_exp_avg_sq = 0.0f; + } + + //// quantize the first-order and second-order momentum + int wid = threadIdx.x / WARPSIZE; + + // reduction within a warp + __shared__ float sharedFirstMaxVal[32]; + __shared__ float sharedFirstMinVal[32]; + __shared__ float sharedSecondMaxVal[32]; + __shared__ float sharedSecondMinVal[32]; + cg::thread_block_tile<32> warpTile = + cg::tiled_partition<32>(cg::this_thread_block()); + float firstMaxVal = fabsf(float_exp_avg); + float firstMinVal = fabsf(float_exp_avg); + float secondMaxVal = fabsf(float_exp_avg_sq); + float secondMinVal = fabsf(float_exp_avg_sq); + // Special Handel + if (idx >= total_elements) { + firstMinVal = __int_as_float(0x7f7fffff); + secondMinVal = __int_as_float(0x7f7fffff); + } + + for (int i = warpTile.size() / 2; i > 0; i /= 2) { + float reduceFirstMaxVal = warpTile.shfl_down(firstMaxVal, i); + float reduceFirstMinVal = warpTile.shfl_down(firstMinVal, i); + float reduceSecondMaxVal = warpTile.shfl_down(secondMaxVal, i); + float reduceSecondMinVal = warpTile.shfl_down(secondMinVal, i); + firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal)); + firstMinVal = fmin(firstMinVal, fabsf(reduceFirstMinVal)); + secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal)); + secondMinVal = fmin(secondMinVal, fabsf(reduceSecondMinVal)); + // printf("First Max: %f\n", reduceFirstMaxVal); + } + int lane = warpTile.thread_rank(); + if (lane == 0) { + sharedFirstMaxVal[wid] = firstMaxVal; + 
sharedFirstMinVal[wid] = firstMinVal; + sharedSecondMaxVal[wid] = secondMaxVal; + sharedSecondMinVal[wid] = secondMinVal; + } + + __syncthreads(); + + // reduction within a block + __shared__ float shared_absmax_exp_avg; + __shared__ float shared_absmin_exp_avg; + __shared__ float shared_absmax_exp_avg_sq; + __shared__ float shared_absmin_exp_avg_sq; + firstMaxVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMaxVal[lane] : 0; + firstMinVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMinVal[lane] : 1e9; + secondMaxVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMaxVal[lane] : 0; + secondMinVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMinVal[lane] : 1e9; + if (wid == 0) { + for (int offset = WARPSIZE / 2; offset > 0; offset /= 2) { + float reduceFirstMaxVal = + __shfl_down_sync(0xFFFFFFFF, firstMaxVal, offset); + float reduceFirstMinVal = + __shfl_down_sync(0xFFFFFFFF, firstMinVal, offset); + float reduceSecondMaxVal = + __shfl_down_sync(0xFFFFFFFF, secondMaxVal, offset); + float reduceSecondMinVal = + __shfl_down_sync(0xFFFFFFFF, secondMinVal, offset); + firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal)); + firstMinVal = fmin(firstMinVal, fabsf(reduceFirstMinVal)); + secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal)); + secondMinVal = fmin(secondMinVal, fabsf(reduceSecondMinVal)); + } + if (lane == 0) { + shared_absmax_exp_avg = firstMaxVal; + shared_absmin_exp_avg = firstMinVal; + shared_absmax_exp_avg_sq = secondMaxVal; + shared_absmin_exp_avg_sq = secondMinVal; + } + } + + __syncthreads(); + + if (idx < total_elements) { + // float fp8MaxVal = fp8_dtype_max<__nv_fp8_e4m3>(exp_avg[idx]); + // scaling factor before expanding + float fp8MaxVal = 448; + + // dynamic exponent quantization part + firstMaxVal = shared_absmax_exp_avg + QUANT_MIN_VAL; + firstMinVal = shared_absmin_exp_avg + QUANT_MIN_VAL; + secondMaxVal = shared_absmax_exp_avg_sq + QUANT_MIN_VAL; + secondMinVal = shared_absmin_exp_avg_sq + QUANT_MIN_VAL; + + // calculate the ratio and make the scale to center + float firstRatio = firstMaxVal / firstMinVal; + float secondRatio = secondMaxVal / secondMinVal; + float firstSqrtMinMax = sqrt(firstMaxVal * firstMinVal); + float secondSqrtMinMax = sqrt(secondMaxVal * secondMinVal); + + // printf("Max %f, Min %f, Origin %f \n", firstMaxVal, firstMinVal, + // float_exp_avg); + + // since we use x^k expander, calculate the optimal expanding factor + float ratioUpperBound = fp8MaxVal * fp8MaxVal / 2; + float firstExp = + floor((log2f(ratioUpperBound) / log2f(firstRatio)) * expand_min) / + expand_min; // expand_min is set to 8 for example, then the firstExp is + // the multiple of 1/8 + float secondExp = + floor((log2f(ratioUpperBound) / log2f(secondRatio)) * expand_min) / + expand_min; + + int sign_exp_avg = 1 - 2 * signbit(float_exp_avg); + float_exp_avg = + sign_exp_avg * powf(fabsf(float_exp_avg) / firstSqrtMinMax, firstExp); + float_exp_avg_sq = powf(float_exp_avg_sq / secondSqrtMinMax, secondExp); + + // correspondingly, change the scaling factor + float new_scale_exp_avg = + powf(firstMaxVal / firstSqrtMinMax, firstExp) / fp8MaxVal; + float new_scale_exp_avg_sq = + powf(secondMaxVal / secondSqrtMinMax, secondExp) / fp8MaxVal; + + // quantize the optimizer states + __nv_fp8_e4m3 exp_avg_new = + static_cast<__nv_fp8_e4m3>(float_exp_avg / new_scale_exp_avg); + __nv_fp8_e4m3 exp_avg_sq_new = + static_cast<__nv_fp8_e4m3>(float_exp_avg_sq / new_scale_exp_avg_sq); + // __half exp_avg_new = static_cast<__half>(float_exp_avg / + 
// new_scale_exp_avg); + // __half exp_avg_sq_new = static_cast<__half>(float_exp_avg_sq / + // new_scale_exp_avg_sq); + + // printf("idx: %d, float: %f, quantize: %f\n", idx, float_exp_avg, + // (float)exp_avg_new * new_scale_exp_avg); + + // store the output + exp_avg[idx] = exp_avg_new; + exp_avg_sq[idx] = exp_avg_sq_new; + scale_exp_avg[scale_idx] = new_scale_exp_avg; + scale_exp_avg_sq[scale_idx] = new_scale_exp_avg_sq; + expand_exp_avg[scale_idx] = firstExp; + expand_exp_avg_sq[scale_idx] = secondExp; + sqrtminmax_exp_avg[scale_idx] = firstSqrtMinMax; + sqrtminmax_exp_avg_sq[scale_idx] = secondSqrtMinMax; + } +} + +template +void myKernelLauncher(float *params, float *grads, __nv_fp8_e4m3 *exp_avg, + float *scale_exp_avg, float *expand_exp_avg, + float *sqrtminmax_exp_avg, __nv_fp8_e4m3 *exp_avg_sq, + float *scale_exp_avg_sq, float *expand_exp_avg_sq, + float *sqrtminmax_exp_avg_sq, float beta1, float beta2, + float lr, float wd, float eps, int step, int qgroup_size, + int expand_min, int M, int N) { + if (algo == fp8_adamw) { + const int block_dim = 128; + int grid_dim = (M * N + qgroup_size - 1) / block_dim; + const dim3 gridDim(grid_dim); + const dim3 blockDim(block_dim); + printf("Yes!\n"); + fp8_adamw_csrc<<>>( + params, grads, exp_avg, scale_exp_avg, expand_exp_avg, + sqrtminmax_exp_avg, exp_avg_sq, scale_exp_avg_sq, expand_exp_avg_sq, + sqrtminmax_exp_avg_sq, beta1, beta2, lr, wd, eps, step, qgroup_size, + expand_min, M * N, int(floor(M * N / 128.))); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) { + std::cout << "CUDA error occurred in kernel launch: " + << cudaGetErrorString(error) << std::endl; + return; + } + printf("Finish!\n"); + } +} + +float testMaxError(void (*myGPUKernel)(float *, float *, __nv_fp8_e4m3 *, + float *, float *, float *, + __nv_fp8_e4m3 *, float *, float *, + float *, // tensor input + float, float, float, float, float, int, + int, int, // float and int input + int, int), // M and N + int M, int N) { + size_t size_param = M * N * sizeof(float); + size_t size_optim = M * N * sizeof(__nv_fp8_e4m3); + size_t size_scale = int(ceil(M * N / 128.)) * sizeof(float); + + // host tensor + float *h_p, *h_g; + __nv_fp8_e4m3 *h_m, *h_v; + float *h_sm, *h_sv; + float *h_fp_m, *h_fp_v; + float *h_cpd_m, *h_cpd_v; + float *h_sqrtmm_m, *h_sqrtmm_v; + + // device tensor + float *d_p, *d_g; + __nv_fp8_e4m3 *d_m, *d_v; + float *d_sm, *d_sv; + float *d_cpd_m, *d_cpd_v; + float *d_sqrtmm_m, *d_sqrtmm_v; + + // device tensor transfer to host + float *hd_p, *hd_g; + __nv_fp8_e4m3 *hd_m, *hd_v; + float *hd_sm, *hd_sv; + float *hd_fp_m, *hd_fp_v; + float *hd_cpd_m, *hd_cpd_v; + float *hd_sqrtmm_m, *hd_sqrtmm_v; + + h_p = (float *)malloc(size_param); + h_g = (float *)malloc(size_param); + h_m = (__nv_fp8_e4m3 *)malloc(size_optim); + h_v = (__nv_fp8_e4m3 *)malloc(size_optim); + h_sm = (float *)malloc(size_scale); + h_sv = (float *)malloc(size_scale); + h_cpd_m = (float *)malloc(size_scale); + h_cpd_v = (float *)malloc(size_scale); + h_sqrtmm_m = (float *)malloc(size_scale); + h_sqrtmm_v = (float *)malloc(size_scale); + h_sv = (float *)malloc(size_scale); + h_fp_m = (float *)malloc(size_param); + h_fp_v = (float *)malloc(size_param); + cudaMalloc(&d_p, size_param); + cudaMalloc(&d_g, size_param); + cudaMalloc(&d_m, size_optim); + cudaMalloc(&d_v, size_optim); + cudaMalloc(&d_sm, size_scale); + cudaMalloc(&d_sv, size_scale); + cudaMalloc(&d_cpd_m, size_scale); + cudaMalloc(&d_cpd_v, size_scale); + cudaMalloc(&d_sqrtmm_m, size_scale); + cudaMalloc(&d_sqrtmm_v, 
size_scale); + hd_p = (float *)malloc(size_param); + hd_g = (float *)malloc(size_param); + hd_m = (__nv_fp8_e4m3 *)malloc(size_optim); + hd_v = (__nv_fp8_e4m3 *)malloc(size_optim); + hd_sm = (float *)malloc(size_scale); + hd_sv = (float *)malloc(size_scale); + hd_fp_m = (float *)malloc(size_param); + hd_fp_v = (float *)malloc(size_param); + hd_cpd_m = (float *)malloc(size_scale); + hd_cpd_v = (float *)malloc(size_scale); + hd_sqrtmm_m = (float *)malloc(size_scale); + hd_sqrtmm_v = (float *)malloc(size_scale); + + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) { + std::cout << "CUDA error occurred in data copy: " + << cudaGetErrorString(error) << std::endl; + return 0.; + } + + srand(0); + // random initialization for CPU tensor + for (int i = 0; i < M * N; i++) { + h_p[i] = (float)(rand() / (float(RAND_MAX) / 10)); + h_g[i] = (float)(rand() / (float(RAND_MAX) / 10)); + h_m[i] = (__nv_fp8_e4m3)(rand() / (float(RAND_MAX) / 10)); + h_v[i] = (__nv_fp8_e4m3)(rand() / (float(RAND_MAX) / 10)); + } + for (int i = 0; i < int(ceilf(M * N / 128.)); i++) { + h_sm[i] = (float)(rand() / (float(RAND_MAX) / 10)); + h_sv[i] = (float)(rand() / (float(RAND_MAX) / 10)); + h_cpd_m[i] = 2; + h_cpd_v[i] = 3. / 8.; + h_sqrtmm_m[i] = (float)(rand() / (float(RAND_MAX) / 10)); + h_sqrtmm_v[i] = (float)(rand() / (float(RAND_MAX) / 10)); + + printf("scale is %f\n", h_sm[i]); + } + for (int i = 0; i < M * N; i++) { + h_fp_m[i] = (float)h_m[i] * h_sm[int(floor(i / 128.))]; + h_fp_v[i] = (float)h_v[i] * h_sv[int(floor(i / 128.))]; + } + float beta1 = 0.9, beta2 = 0.95, lr = 4e-4, wd = 0.1, eps = 1e-8; + int step = 100, qgroup_size = 128, expand_min = 16; + + printFloatArrayToFile(h_p, M, N, "Past_CPU_param.txt"); + printFloatArrayToFile(h_g, M, N, "Past_CPU_grad.txt"); + printFloatArrayToFile(h_m, M, N, "Past_CPU_m1.txt"); + printFloatArrayToFile(h_sm, 1, int(ceilf(M * N / 128.)), "Past_CPU_ms.txt"); + printFloatArrayToFile(h_fp_m, M, N, "Past_CPU_mf.txt"); + printFloatArrayToFile(h_v, M, N, "Past_CPU_v2.txt"); + printFloatArrayToFile(h_sv, 1, int(ceilf(M * N / 128.)), "Past_CPU_vs.txt"); + printFloatArrayToFile(h_fp_v, M, N, "Past_CPU_vf.txt"); + + cudaMemcpy(d_p, h_p, size_param, cudaMemcpyHostToDevice); + cudaMemcpy(d_g, h_g, size_param, cudaMemcpyHostToDevice); + cudaMemcpy(d_m, h_m, size_optim, cudaMemcpyHostToDevice); + cudaMemcpy(d_v, h_v, size_optim, cudaMemcpyHostToDevice); + cudaMemcpy(d_sm, h_sm, size_scale, cudaMemcpyHostToDevice); + cudaMemcpy(d_sv, h_sv, size_scale, cudaMemcpyHostToDevice); + cudaMemcpy(d_cpd_m, h_cpd_m, size_scale, cudaMemcpyHostToDevice); + cudaMemcpy(d_cpd_v, h_cpd_v, size_scale, cudaMemcpyHostToDevice); + cudaMemcpy(d_sqrtmm_m, h_sqrtmm_m, size_scale, cudaMemcpyHostToDevice); + cudaMemcpy(d_sqrtmm_v, h_sqrtmm_v, size_scale, cudaMemcpyHostToDevice); + + fp8_adamw_cpu(h_p, h_g, h_fp_m, h_fp_v, beta1, beta2, lr, wd, eps, step, + qgroup_size, M, N); + + if (error != cudaSuccess) { + std::cout << "CUDA error occurred in data initialization: " + << cudaGetErrorString(error) << std::endl; + return 0.; + } + + myGPUKernel(d_p, d_g, d_m, d_sm, d_cpd_m, d_sqrtmm_m, d_v, d_sv, d_cpd_v, + d_sqrtmm_v, beta1, beta2, lr, wd, eps, step, qgroup_size, + expand_min, M, N); + + cudaMemcpy(hd_p, d_p, size_param, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_g, d_g, size_param, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_m, d_m, size_optim, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_v, d_v, size_optim, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_sm, d_sm, size_scale, cudaMemcpyDeviceToHost); + 
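+  // The remaining device-to-host copies fetch the per-group metadata written
+  // by the expand kernel: the second-moment scales plus the expand exponents
+  // (hd_cpd_*) and the sqrt(max*min) centers (hd_sqrtmm_*). The reconstruction
+  // loop further below inverts the quantization, x ~ (q * scale)^(1/k) * sqrt(max*min),
+  // so GPU_mf.txt / GPU_vf.txt can be compared against the CPU_mf.txt / CPU_vf.txt
+  // reference produced by fp8_adamw_cpu.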
cudaMemcpy(hd_sv, d_sv, size_scale, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_cpd_m, d_cpd_m, size_scale, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_cpd_v, d_cpd_v, size_scale, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_sqrtmm_m, d_sqrtmm_m, size_scale, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_sqrtmm_v, d_sqrtmm_v, size_scale, cudaMemcpyDeviceToHost); + + for (int i = 0; i < M * N; i++) { + hd_fp_m[i] = pow((float)hd_m[i] * hd_sm[int(floor(i / 128.))], + 1 / hd_cpd_m[int(floor(i / 128.))]) * + hd_sqrtmm_m[int(floor(i / 128.))]; + hd_fp_v[i] = pow((float)hd_v[i] * hd_sv[int(floor(i / 128.))], + 1 / hd_cpd_v[int(floor(i / 128.))]) * + hd_sqrtmm_v[int(floor(i / 128.))]; + } + printFloatArrayToFile(h_p, M, N, "CPU_param.txt"); + printFloatArrayToFile(hd_p, M, N, "GPU_param.txt"); + printFloatArrayToFile(h_g, M, N, "CPU_grad.txt"); + printFloatArrayToFile(hd_g, M, N, "GPU_grad.txt"); + printFloatArrayToFile(h_m, M, N, "CPU_m1.txt"); + printFloatArrayToFile(h_sm, 1, int(ceilf(M * N / 128.)), "CPU_ms.txt"); + printFloatArrayToFile(h_fp_m, M, N, "CPU_mf.txt"); + printFloatArrayToFile(hd_m, M, N, "GPU_m1.txt"); + printFloatArrayToFile(hd_sm, 1, int(ceilf(M * N / 128.)), "GPU_ms.txt"); + printFloatArrayToFile(hd_fp_m, M, N, "GPU_mf.txt"); + printFloatArrayToFile(h_v, M, N, "CPU_v2.txt"); + printFloatArrayToFile(h_sv, 1, int(ceilf(M * N / 128.)), "CPU_vs.txt"); + printFloatArrayToFile(h_fp_v, M, N, "CPU_vf.txt"); + printFloatArrayToFile(hd_v, M, N, "GPU_v2.txt"); + printFloatArrayToFile(hd_sv, 1, int(ceilf(M * N / 128.)), "GPU_vs.txt"); + printFloatArrayToFile(hd_fp_v, M, N, "GPU_vf.txt"); + + printFloatArrayToFile(hd_cpd_m, 1, int(ceilf(M * N / 128.)), "GPU_cpd_m.txt"); + printFloatArrayToFile(hd_cpd_v, 1, int(ceilf(M * N / 128.)), "GPU_cpd_v.txt"); + printFloatArrayToFile(hd_sqrtmm_m, 1, int(ceilf(M * N / 128.)), + "GPU_sqrtmm_m.txt"); + printFloatArrayToFile(hd_sqrtmm_v, 1, int(ceilf(M * N / 128.)), + "GPU_sqrtmm_v.txt"); + + return 0.; +} + +int main() { + const int M = 1, N = 7; + float max_error = testMaxError(myKernelLauncher, M, N); +} diff --git a/llava/model/coat/optimizer/kernels/csrc_origin_quantize/makefile b/llava/model/coat/optimizer/kernels/csrc_origin_quantize/makefile new file mode 100644 index 0000000000000000000000000000000000000000..ad4632ffe641f2043c2cbc4966dff66af2807c26 --- /dev/null +++ b/llava/model/coat/optimizer/kernels/csrc_origin_quantize/makefile @@ -0,0 +1,6 @@ +all: + nvcc nvcc_qoptim.cu -o nvcc_qoptim -gencode=arch=compute_90,code=compute_90 +run: + ./nvcc_qoptim +clean: + rm -f nvcc_qoptim diff --git a/llava/model/coat/optimizer/kernels/csrc_origin_quantize/nvcc_qoptim.cu b/llava/model/coat/optimizer/kernels/csrc_origin_quantize/nvcc_qoptim.cu new file mode 100644 index 0000000000000000000000000000000000000000..dde952654b71f11dc6de051a7089df38e14c6b57 --- /dev/null +++ b/llava/model/coat/optimizer/kernels/csrc_origin_quantize/nvcc_qoptim.cu @@ -0,0 +1,354 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace cg = cooperative_groups; +#define WARPSIZE 32 +#define QGROUPSIZE 128 +#define QUANT_MIN_VAL 1e-20 + +template +inline float fp8_dtype_max(const T &variable) { + if (std::is_same::value) { + return 448; + } else if (std::is_same::value) { + return 57344; + } else { + throw "Unsupported data format"; + } +} + +typedef enum { fp8_adamw } myCsrcKernels; + +void fp8_adamw_cpu(float *params, float *grads, float *fp_exp_avg, + float *fp_exp_avg_sq, float beta1, float beta2, float lr, + 
float wd, float eps, int step, int qgroup_size, int M, + int N) { + for (int idx = 0; idx < M * N; idx++) { + fp_exp_avg[idx] = beta1 * fp_exp_avg[idx] + (1 - beta1) * grads[idx]; + fp_exp_avg_sq[idx] = + beta2 * fp_exp_avg_sq[idx] + (1 - beta2) * grads[idx] * grads[idx]; + + const float correction1 = 1.0f - powf(beta1, step); + const float correction2_sqrt = sqrtf(1.0f - powf(beta2, step)); + + float denom = + (sqrtf(fp_exp_avg_sq[idx]) / correction2_sqrt + eps) * correction1; + float update = (fp_exp_avg[idx] / denom) + (wd * params[idx]); + params[idx] = params[idx] - (lr * update); + } +} + +template +void printFloatArrayToFile(T *array, int M, int N, + const std::string &outputFileName) { + std::ofstream outputFile(outputFileName); + if (!outputFile.is_open()) { + std::cout << "Failed to open the file." << std::endl; + return; + } + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + int index = i * N + j; + outputFile << std::setw(10) << std::right << std::fixed + << std::setprecision(6) << (float)array[index] << " "; + if (j == N - 1) { + outputFile << "\n"; + } + } + } +} + +template +__global__ void fp8_adamw_csrc(scalar_t *__restrict__ params, + scalar_t *__restrict__ grads, + __nv_fp8_e4m3 *__restrict__ exp_avg, + float *__restrict__ scale_exp_avg, + __nv_fp8_e4m3 *__restrict__ exp_avg_sq, + float *__restrict__ scale_exp_avg_sq, + float beta1, float beta2, float lr, float wd, + float eps, int step, int qgroup_size, + int total_elements, int total_scale_elements) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int scale_idx = blockIdx.x; + + float float_exp_avg, float_exp_avg_sq; + float correction1, correction2_sqrt; + float denom, update; + + if (idx < total_elements) { + // dequantize the optimizer states + float_exp_avg = float(exp_avg[idx]) * scale_exp_avg[scale_idx]; + float_exp_avg_sq = float(exp_avg_sq[idx]) * scale_exp_avg_sq[scale_idx]; + + // calculation of optimizer.step() + float_exp_avg = beta1 * float_exp_avg + (1 - beta1) * grads[idx]; + float_exp_avg_sq = + beta2 * float_exp_avg_sq + (1 - beta2) * grads[idx] * grads[idx]; + + correction1 = 1.0f - powf(beta1, step); + correction2_sqrt = sqrtf(1.0f - powf(beta2, step)); + + denom = (sqrtf(float_exp_avg_sq) / correction2_sqrt + eps) * correction1; + update = (float_exp_avg / denom) + (wd * params[idx]); + + params[idx] = params[idx] - (lr * update); + } else { + float_exp_avg = 0.0f; + float_exp_avg_sq = 0.0f; + } + + //// quantize the first-order and second-order momentum + int wid = threadIdx.x / WARPSIZE; + + // reduction within a warp + + __shared__ float sharedFirstMaxVal[32]; + __shared__ float sharedSecondMaxVal[32]; + cg::thread_block_tile<32> warpTile = + cg::tiled_partition<32>(cg::this_thread_block()); + float firstMaxVal = fabsf(float_exp_avg); + float secondMaxVal = fabsf(float_exp_avg_sq); + + for (int i = warpTile.size() / 2; i > 0; i /= 2) { + float reduceFirstMaxVal = warpTile.shfl_down(firstMaxVal, i); + float reduceSecondMaxVal = warpTile.shfl_down(secondMaxVal, i); + firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal)); + secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal)); + // printf("First Max: %f\n", reduceFirstMaxVal); + } + int lane = warpTile.thread_rank(); + if (lane == 0) sharedFirstMaxVal[wid] = firstMaxVal; + if (lane == 0) sharedSecondMaxVal[wid] = secondMaxVal; + + __syncthreads(); + + // reduction within a block + __shared__ float shared_absmax_exp_avg; + __shared__ float shared_absmax_exp_avg_sq; + firstMaxVal = + (threadIdx.x < 
blockDim.x / warpSize) ? sharedFirstMaxVal[lane] : 0; + secondMaxVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMaxVal[lane] : 0; + if (wid == 0) { + for (int offset = WARPSIZE / 2; offset > 0; offset /= 2) { + float reduceFirstMaxVal = + __shfl_down_sync(0xFFFFFFFF, firstMaxVal, offset); + float reduceSecondMaxVal = + __shfl_down_sync(0xFFFFFFFF, secondMaxVal, offset); + firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal)); + secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal)); + } + if (lane == 0) shared_absmax_exp_avg = firstMaxVal; + if (lane == 0) shared_absmax_exp_avg_sq = secondMaxVal; + } + + __syncthreads(); + + if (idx < total_elements) { + // float fp8MaxVal = fp8_dtype_max<__nv_fp8_e4m3>(exp_avg[idx]); + float fp8MaxVal = 448; + + shared_absmax_exp_avg = shared_absmax_exp_avg + QUANT_MIN_VAL; + shared_absmax_exp_avg_sq = shared_absmax_exp_avg_sq + QUANT_MIN_VAL; + + float new_scale_exp_avg = shared_absmax_exp_avg / fp8MaxVal; + float new_scale_exp_avg_sq = shared_absmax_exp_avg_sq / fp8MaxVal; + + // quantize the optimizer states + __nv_fp8_e4m3 exp_avg_new = + static_cast<__nv_fp8_e4m3>(float_exp_avg / new_scale_exp_avg); + __nv_fp8_e4m3 exp_avg_sq_new = + static_cast<__nv_fp8_e4m3>(float_exp_avg_sq / new_scale_exp_avg_sq); + // __half exp_avg_new = static_cast<__half>(float_exp_avg / + // new_scale_exp_avg); + // __half exp_avg_sq_new = static_cast<__half>(float_exp_avg_sq / + // new_scale_exp_avg_sq); + + // printf("idx: %d, float: %f, quantize: %f\n", idx, float_exp_avg, + // (float)exp_avg_new * new_scale_exp_avg); + + // store the output + exp_avg[idx] = exp_avg_new; + exp_avg_sq[idx] = exp_avg_sq_new; + scale_exp_avg[scale_idx] = new_scale_exp_avg; + scale_exp_avg_sq[scale_idx] = new_scale_exp_avg_sq; + } +} + +template +void myKernelLauncher(float *params, float *grads, __nv_fp8_e4m3 *exp_avg, + float *scale_exp_avg, __nv_fp8_e4m3 *exp_avg_sq, + float *scale_exp_avg_sq, float beta1, float beta2, + float lr, float wd, float eps, int step, int qgroup_size, + int M, int N) { + if (algo == fp8_adamw) { + const int block_dim = 128; + int grid_dim = (M * N + qgroup_size - 1) / block_dim; + const dim3 gridDim(grid_dim); + const dim3 blockDim(block_dim); + printf("Yes!\n"); + fp8_adamw_csrc<<>>( + params, grads, exp_avg, scale_exp_avg, exp_avg_sq, scale_exp_avg_sq, + beta1, beta2, lr, wd, eps, step, qgroup_size, M * N, + int(floor(M * N / 128.))); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) { + std::cout << "CUDA error occurred in kernel launch: " + << cudaGetErrorString(error) << std::endl; + return; + } + printf("Finish!\n"); + } +} + +float testMaxError(void (*myGPUKernel)(float *, float *, __nv_fp8_e4m3 *, + float *, __nv_fp8_e4m3 *, float *, float, + float, float, float, float, int, int, + int, int), + int M, int N) { + size_t size_param = M * N * sizeof(float); + size_t size_optim = M * N * sizeof(__nv_fp8_e4m3); + size_t size_scale = int(ceil(M * N / 128.)) * sizeof(float); + + // host tensor + float *h_p, *h_g; + __nv_fp8_e4m3 *h_m, *h_v; + float *h_sm, *h_sv; + float *h_fp_m, *h_fp_v; + + // device tensor + float *d_p, *d_g; + __nv_fp8_e4m3 *d_m, *d_v; + float *d_sm, *d_sv; + + // device tensor transfer to host + float *hd_p, *hd_g; + __nv_fp8_e4m3 *hd_m, *hd_v; + float *hd_sm, *hd_sv; + float *hd_fp_m, *hd_fp_v; + + h_p = (float *)malloc(size_param); + h_g = (float *)malloc(size_param); + h_m = (__nv_fp8_e4m3 *)malloc(size_optim); + h_v = (__nv_fp8_e4m3 *)malloc(size_optim); + h_sm = (float *)malloc(size_scale); + 
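+  // size_scale budgets one float per 128-element quantization group,
+  // i.e. ceil(M*N/128) entries; the kernel writes one scale per thread
+  // block (scale_idx = blockIdx.x) with block_dim = qgroup_size = 128.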
h_sv = (float *)malloc(size_scale); + h_fp_m = (float *)malloc(size_param); + h_fp_v = (float *)malloc(size_param); + cudaMalloc(&d_p, size_param); + cudaMalloc(&d_g, size_param); + cudaMalloc(&d_m, size_optim); + cudaMalloc(&d_v, size_optim); + cudaMalloc(&d_sm, size_scale); + cudaMalloc(&d_sv, size_scale); + hd_p = (float *)malloc(size_param); + hd_g = (float *)malloc(size_param); + hd_m = (__nv_fp8_e4m3 *)malloc(size_optim); + hd_v = (__nv_fp8_e4m3 *)malloc(size_optim); + hd_sm = (float *)malloc(size_scale); + hd_sv = (float *)malloc(size_scale); + hd_fp_m = (float *)malloc(size_param); + hd_fp_v = (float *)malloc(size_param); + + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) { + std::cout << "CUDA error occurred in data copy: " + << cudaGetErrorString(error) << std::endl; + return 0.; + } + + srand(0); + // random initialization for CPU tensor + for (int i = 0; i < M * N; i++) { + h_p[i] = (float)(rand() / (float(RAND_MAX) / 10)); + h_g[i] = (float)(rand() / (float(RAND_MAX) / 10)); + h_m[i] = (__nv_fp8_e4m3)(rand() / (float(RAND_MAX) / 10)); + h_v[i] = (__nv_fp8_e4m3)(rand() / (float(RAND_MAX) / 10)); + } + for (int i = 0; i < int(ceilf(M * N / 128.)); i++) { + h_sm[i] = (float)(rand() / (float(RAND_MAX) / 10)); + h_sv[i] = (float)(rand() / (float(RAND_MAX) / 10)); + printf("scale is %f\n", h_sm[i]); + } + for (int i = 0; i < M * N; i++) { + h_fp_m[i] = (float)h_m[i] * h_sm[int(floor(i / 128.))]; + h_fp_v[i] = (float)h_v[i] * h_sv[int(floor(i / 128.))]; + } + float beta1 = 0.9, beta2 = 0.95, lr = 4e-4, wd = 0.1, eps = 1e-8; + int step = 100, qgroup_size = 128; + + printFloatArrayToFile(h_p, M, N, "Past_CPU_param.txt"); + printFloatArrayToFile(h_g, M, N, "Past_CPU_grad.txt"); + printFloatArrayToFile(h_m, M, N, "Past_CPU_m1.txt"); + printFloatArrayToFile(h_sm, 1, int(ceilf(M * N / 128.)), "Past_CPU_ms.txt"); + printFloatArrayToFile(h_fp_m, M, N, "Past_CPU_mf.txt"); + printFloatArrayToFile(h_v, M, N, "Past_CPU_v2.txt"); + printFloatArrayToFile(h_sv, 1, int(ceilf(M * N / 128.)), "Past_CPU_vs.txt"); + printFloatArrayToFile(h_fp_v, M, N, "Past_CPU_vf.txt"); + + cudaMemcpy(d_p, h_p, size_param, cudaMemcpyHostToDevice); + cudaMemcpy(d_g, h_g, size_param, cudaMemcpyHostToDevice); + cudaMemcpy(d_m, h_m, size_optim, cudaMemcpyHostToDevice); + cudaMemcpy(d_v, h_v, size_optim, cudaMemcpyHostToDevice); + cudaMemcpy(d_sm, h_sm, size_scale, cudaMemcpyHostToDevice); + cudaMemcpy(d_sv, h_sv, size_scale, cudaMemcpyHostToDevice); + + fp8_adamw_cpu(h_p, h_g, h_fp_m, h_fp_v, beta1, beta2, lr, wd, eps, step, + qgroup_size, M, N); + + if (error != cudaSuccess) { + std::cout << "CUDA error occurred in data initialization: " + << cudaGetErrorString(error) << std::endl; + return 0.; + } + + myGPUKernel(d_p, d_g, d_m, d_sm, d_v, d_sv, beta1, beta2, lr, wd, eps, step, + qgroup_size, M, N); + + cudaMemcpy(hd_p, d_p, size_param, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_g, d_g, size_param, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_m, d_m, size_optim, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_v, d_v, size_optim, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_sm, d_sm, size_scale, cudaMemcpyDeviceToHost); + cudaMemcpy(hd_sv, d_sv, size_scale, cudaMemcpyDeviceToHost); + + for (int i = 0; i < M * N; i++) { + hd_fp_m[i] = (float)hd_m[i] * hd_sm[int(floor(i / 128.))]; + hd_fp_v[i] = (float)hd_v[i] * hd_sv[int(floor(i / 128.))]; + } + printFloatArrayToFile(h_p, M, N, "CPU_param.txt"); + printFloatArrayToFile(hd_p, M, N, "GPU_param.txt"); + printFloatArrayToFile(h_g, M, N, "CPU_grad.txt"); + 
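+  // The CPU_* / GPU_* dumps give a side-by-side check of the host reference
+  // (updated by fp8_adamw_cpu) against the kernel output; GPU_mf.txt and
+  // GPU_vf.txt hold the momenta dequantized above as q * per-group scale,
+  // to be compared with CPU_mf.txt and CPU_vf.txt.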
printFloatArrayToFile(hd_g, M, N, "GPU_grad.txt"); + printFloatArrayToFile(h_m, M, N, "CPU_m1.txt"); + printFloatArrayToFile(h_sm, 1, int(ceilf(M * N / 128.)), "CPU_ms.txt"); + printFloatArrayToFile(h_fp_m, M, N, "CPU_mf.txt"); + printFloatArrayToFile(hd_m, M, N, "GPU_m1.txt"); + printFloatArrayToFile(hd_sm, 1, int(ceilf(M * N / 128.)), "GPU_ms.txt"); + printFloatArrayToFile(hd_fp_m, M, N, "GPU_mf.txt"); + printFloatArrayToFile(h_v, M, N, "CPU_v2.txt"); + printFloatArrayToFile(h_sv, 1, int(ceilf(M * N / 128.)), "CPU_vs.txt"); + printFloatArrayToFile(h_fp_v, M, N, "CPU_vf.txt"); + printFloatArrayToFile(hd_v, M, N, "GPU_v2.txt"); + printFloatArrayToFile(hd_sv, 1, int(ceilf(M * N / 128.)), "GPU_vs.txt"); + printFloatArrayToFile(hd_fp_v, M, N, "GPU_vf.txt"); + + return 0.; +} + +int main() { + const int M = 1, N = 7; + float max_error = testMaxError(myKernelLauncher, M, N); +} diff --git a/llava/model/coat/optimizer/kernels/fp8_adamw_cuda.cpp b/llava/model/coat/optimizer/kernels/fp8_adamw_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..57190c713309a19eafe35b4651e55d7c5bd576f7 --- /dev/null +++ b/llava/model/coat/optimizer/kernels/fp8_adamw_cuda.cpp @@ -0,0 +1,26 @@ +#include +#include + +void FP8_AdamW_cuda(torch::Tensor params, // parameter + torch::Tensor grads, // gradient + torch::Tensor exp_avg, // first order momentum + torch::Tensor scale_exp_avg, + torch::Tensor exp_avg_sq, // second order momentum + torch::Tensor scale_exp_avg_sq, float beta1, float beta2, + float lr, float wd, float eps, int step, + int qgroup_size // other parameters +); + +void FP8_AdamW(torch::Tensor params, // parameter + torch::Tensor grads, // gradient + torch::Tensor exp_avg, // first order momentum + torch::Tensor scale_exp_avg, + torch::Tensor exp_avg_sq, // second order momentum + torch::Tensor scale_exp_avg_sq, float beta1, float beta2, + float lr, float wd, float eps, int step, + int qgroup_size) { // other parameters + + FP8_AdamW_cuda(params, grads, exp_avg, scale_exp_avg, exp_avg_sq, + scale_exp_avg_sq, beta1, beta2, lr, wd, eps, step, + qgroup_size); +} diff --git a/llava/model/coat/optimizer/kernels/fp8_adamw_cuda_kernel.cu b/llava/model/coat/optimizer/kernels/fp8_adamw_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..9e652845af5d228237d3c8df758df1a4fe94bcd4 --- /dev/null +++ b/llava/model/coat/optimizer/kernels/fp8_adamw_cuda_kernel.cu @@ -0,0 +1,163 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define QUANT_MIN_VAL 1e-20 + +namespace cg = cooperative_groups; +#define WARPSIZE 32 + +template +__global__ void fp8_adamw_cuda_kernel( + scalar_t* __restrict__ params, scalar_t* __restrict__ grads, + __nv_fp8_e4m3* __restrict__ exp_avg, float* __restrict__ scale_exp_avg, + __nv_fp8_e4m3* __restrict__ exp_avg_sq, + float* __restrict__ scale_exp_avg_sq, float beta1, float beta2, float lr, + float wd, float eps, int step, int qgroup_size, int total_elements, + int total_scale_elements) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int scale_idx = blockIdx.x; + + float float_exp_avg, float_exp_avg_sq; + float correction1, correction2_sqrt; + float denom, update; + + if (idx < total_elements) { + // dequantize the optimizer states + float_exp_avg = float(exp_avg[idx]) * scale_exp_avg[scale_idx]; + float_exp_avg_sq = float(exp_avg_sq[idx]) * scale_exp_avg_sq[scale_idx]; + + // calculation of optimizer.step() + float_exp_avg = beta1 * float_exp_avg + (1 - beta1) * 
grads[idx]; + float_exp_avg_sq = + beta2 * float_exp_avg_sq + (1 - beta2) * grads[idx] * grads[idx]; + + correction1 = 1.0f - powf(beta1, step); + correction2_sqrt = sqrtf(1.0f - powf(beta2, step)); + + denom = (sqrtf(float_exp_avg_sq) / correction2_sqrt + eps) * correction1; + update = (float_exp_avg / denom) + (wd * params[idx]); + + params[idx] = params[idx] - (lr * update); + } else { + float_exp_avg = 0.0f; + float_exp_avg_sq = 0.0f; + } + + //// quantize the first-order and second-order momentum + int wid = threadIdx.x / WARPSIZE; + + // reduction within a warp + + __shared__ float sharedFirstMaxVal[32]; + __shared__ float sharedSecondMaxVal[32]; + cg::thread_block_tile<32> warpTile = + cg::tiled_partition<32>(cg::this_thread_block()); + float firstMaxVal = fabsf(float_exp_avg); + float secondMaxVal = fabsf(float_exp_avg_sq); + + for (int i = warpTile.size() / 2; i > 0; i /= 2) { + float reduceFirstMaxVal = warpTile.shfl_down(firstMaxVal, i); + float reduceSecondMaxVal = warpTile.shfl_down(secondMaxVal, i); + firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal)); + secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal)); + // printf("First Max: %f\n", reduceFirstMaxVal); + } + int lane = warpTile.thread_rank(); + if (lane == 0) sharedFirstMaxVal[wid] = firstMaxVal; + if (lane == 0) sharedSecondMaxVal[wid] = secondMaxVal; + + __syncthreads(); + + // reduction within a block + __shared__ float shared_absmax_exp_avg; + __shared__ float shared_absmax_exp_avg_sq; + firstMaxVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMaxVal[lane] : 0; + secondMaxVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMaxVal[lane] : 0; + if (wid == 0) { + for (int offset = WARPSIZE / 2; offset > 0; offset /= 2) { + float reduceFirstMaxVal = + __shfl_down_sync(0xFFFFFFFF, firstMaxVal, offset); + float reduceSecondMaxVal = + __shfl_down_sync(0xFFFFFFFF, secondMaxVal, offset); + firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal)); + secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal)); + } + if (lane == 0) shared_absmax_exp_avg = firstMaxVal; + if (lane == 0) shared_absmax_exp_avg_sq = secondMaxVal; + } + + __syncthreads(); + + if (idx < total_elements) { + // float fp8MaxVal = fp8_dtype_max<__nv_fp8_e4m3>(exp_avg[idx]); + float fp8MaxVal = 448; + + shared_absmax_exp_avg = shared_absmax_exp_avg + QUANT_MIN_VAL; + shared_absmax_exp_avg_sq = shared_absmax_exp_avg_sq + QUANT_MIN_VAL; + + float new_scale_exp_avg = shared_absmax_exp_avg / fp8MaxVal; + float new_scale_exp_avg_sq = shared_absmax_exp_avg_sq / fp8MaxVal; + + // quantize the optimizer states + __nv_fp8_e4m3 exp_avg_new = + static_cast<__nv_fp8_e4m3>(float_exp_avg / new_scale_exp_avg); + __nv_fp8_e4m3 exp_avg_sq_new = + static_cast<__nv_fp8_e4m3>(float_exp_avg_sq / new_scale_exp_avg_sq); + // __half exp_avg_new = static_cast<__half>(float_exp_avg / + // new_scale_exp_avg); + // __half exp_avg_sq_new = static_cast<__half>(float_exp_avg_sq / + // new_scale_exp_avg_sq); + + // printf("idx: %d, float: %f, quantize: %f\n", idx, float_exp_avg, + // (float)exp_avg_new * new_scale_exp_avg); + + // store the output + exp_avg[idx] = exp_avg_new; + exp_avg_sq[idx] = exp_avg_sq_new; + scale_exp_avg[scale_idx] = new_scale_exp_avg; + scale_exp_avg_sq[scale_idx] = new_scale_exp_avg_sq; + } +} + +void FP8_AdamW_cuda(torch::Tensor params, // parameter + torch::Tensor grads, // gradient + torch::Tensor exp_avg, // first order momentum + torch::Tensor scale_exp_avg, + torch::Tensor exp_avg_sq, // second order momentum + 
torch::Tensor scale_exp_avg_sq, float beta1, float beta2, + float lr, float wd, float eps, int step, + int qgroup_size) { // other parameters + + // CUDA Blocks + int total_elements = params.numel(); + int total_scale_elements = scale_exp_avg.numel(); + AT_ASSERTM(qgroup_size == 128, + "Only Support 128 per-group quantization currently"); + const int block_dim = 128; // This should equal to the qgroup_size + int grid_dim = (total_elements + qgroup_size - 1) / block_dim; + AT_ASSERTM(grid_dim == scale_exp_avg.numel()); + AT_ASSERTM(grid_dim == scale_exp_avg_sq.numel()); + const dim3 blocks(grid_dim); + + // Execution + AT_DISPATCH_FLOATING_TYPES_AND2( + at::kBFloat16, at::kHalf, params.scalar_type(), "fp8_adamw", ([&] { + fp8_adamw_cuda_kernel<<>>( + params.data_ptr(), grads.data_ptr(), + (__nv_fp8_e4m3*)exp_avg.data_ptr(), + scale_exp_avg.data_ptr(), + (__nv_fp8_e4m3*)exp_avg_sq.data_ptr(), + scale_exp_avg_sq.data_ptr(), beta1, beta2, lr, wd, eps, step, + qgroup_size, total_elements, total_scale_elements); + })); +} diff --git a/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda.cpp b/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d50f96dc9d4dc186b660d6b6ff56ff9df1c2df89 --- /dev/null +++ b/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda.cpp @@ -0,0 +1,34 @@ +#include +#include + +void FP8_AdamW_expand_cuda(torch::Tensor params, // parameter + torch::Tensor grads, // gradient + torch::Tensor exp_avg, // first order momentum + torch::Tensor scale_exp_avg, + torch::Tensor expand_exp_avg, + torch::Tensor sqrtminmax_exp_avg, + torch::Tensor exp_avg_sq, // second order momentum + torch::Tensor scale_exp_avg_sq, + torch::Tensor expand_exp_avg_sq, + torch::Tensor sqrtminmax_exp_avg_sq, float beta1, + float beta2, float lr, float wd, float eps, int step, + int qgroup_size, int expand_min // other parameters +); + +void FP8_AdamW_expand(torch::Tensor params, // parameter + torch::Tensor grads, // gradient + torch::Tensor exp_avg, // first order momentum + torch::Tensor scale_exp_avg, torch::Tensor expand_exp_avg, + torch::Tensor sqrtminmax_exp_avg, + torch::Tensor exp_avg_sq, // second order momentum + torch::Tensor scale_exp_avg_sq, + torch::Tensor expand_exp_avg_sq, + torch::Tensor sqrtminmax_exp_avg_sq, float beta1, + float beta2, float lr, float wd, float eps, int step, + int qgroup_size, int expand_min) { // other parameters + + FP8_AdamW_expand_cuda(params, grads, exp_avg, scale_exp_avg, expand_exp_avg, + sqrtminmax_exp_avg, exp_avg_sq, scale_exp_avg_sq, + expand_exp_avg_sq, sqrtminmax_exp_avg_sq, beta1, beta2, + lr, wd, eps, step, qgroup_size, expand_min); +} diff --git a/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda_kernel.cu b/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..6130e3716fe5d73bf66cc6f915c3bf99a41f785b --- /dev/null +++ b/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda_kernel.cu @@ -0,0 +1,247 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define QUANT_MIN_VAL 1e-20 + +namespace cg = cooperative_groups; +#define WARPSIZE 32 + +template +__global__ void fp8_adamw_cuda_expand_kernel( + scalar_t* __restrict__ params, scalar_t* __restrict__ grads, + __nv_fp8_e4m3* __restrict__ exp_avg, float* __restrict__ scale_exp_avg, + float* __restrict__ expand_exp_avg, float* __restrict__ sqrtminmax_exp_avg, + __nv_fp8_e4m3* 
__restrict__ exp_avg_sq, + float* __restrict__ scale_exp_avg_sq, float* __restrict__ expand_exp_avg_sq, + float* __restrict__ sqrtminmax_exp_avg_sq, float beta1, float beta2, + float lr, float wd, float eps, int step, int qgroup_size, int expand_min, + int total_elements, int total_scale_elements) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int scale_idx = blockIdx.x; + + float float_exp_avg, float_exp_avg_sq; + float correction1, correction2_sqrt; + float denom, update; + + if (idx < total_elements) { + // dequantize the optimizer states + float_exp_avg = float(exp_avg[idx]) * scale_exp_avg[scale_idx]; + int sign_exp_avg = 1 - 2 * signbit(float_exp_avg); + float_exp_avg = sign_exp_avg * + powf(fabsf(float_exp_avg), 1 / expand_exp_avg[scale_idx]) * + sqrtminmax_exp_avg[scale_idx]; + float_exp_avg_sq = float(exp_avg_sq[idx]) * scale_exp_avg_sq[scale_idx]; + float_exp_avg_sq = + powf(float_exp_avg_sq, 1 / expand_exp_avg_sq[scale_idx]) * + sqrtminmax_exp_avg_sq[scale_idx]; + + // calculation of optimizer.step() + float_exp_avg = beta1 * float_exp_avg + (1 - beta1) * grads[idx]; + float_exp_avg_sq = + beta2 * float_exp_avg_sq + (1 - beta2) * grads[idx] * grads[idx]; + + correction1 = 1.0f - powf(beta1, step); + correction2_sqrt = sqrtf(1.0f - powf(beta2, step)); + + denom = (sqrtf(float_exp_avg_sq) / correction2_sqrt + eps) * correction1; + update = (float_exp_avg / denom) + (wd * params[idx]); + params[idx] = params[idx] - (lr * update); + } else { + float_exp_avg = 0.0f; + float_exp_avg_sq = 0.0f; + } + + //// quantize the first-order and second-order momentum + int wid = threadIdx.x / WARPSIZE; + + // reduction within a warp + __shared__ float sharedFirstMaxVal[32]; + __shared__ float sharedFirstMinVal[32]; + __shared__ float sharedSecondMaxVal[32]; + __shared__ float sharedSecondMinVal[32]; + cg::thread_block_tile<32> warpTile = + cg::tiled_partition<32>(cg::this_thread_block()); + float firstMaxVal = fabsf(float_exp_avg); + float firstMinVal = fabsf(float_exp_avg); + float secondMaxVal = fabsf(float_exp_avg_sq); + float secondMinVal = fabsf(float_exp_avg_sq); + // Special Handel + if (idx >= total_elements) { + firstMinVal = __int_as_float(0x7f7fffff); + secondMinVal = __int_as_float(0x7f7fffff); + } + + for (int i = warpTile.size() / 2; i > 0; i /= 2) { + float reduceFirstMaxVal = warpTile.shfl_down(firstMaxVal, i); + float reduceFirstMinVal = warpTile.shfl_down(firstMinVal, i); + float reduceSecondMaxVal = warpTile.shfl_down(secondMaxVal, i); + float reduceSecondMinVal = warpTile.shfl_down(secondMinVal, i); + firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal)); + firstMinVal = fmin(firstMinVal, fabsf(reduceFirstMinVal)); + secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal)); + secondMinVal = fmin(secondMinVal, fabsf(reduceSecondMinVal)); + // printf("First Max: %f\n", reduceFirstMaxVal); + } + int lane = warpTile.thread_rank(); + if (lane == 0) { + sharedFirstMaxVal[wid] = firstMaxVal; + sharedFirstMinVal[wid] = firstMinVal; + sharedSecondMaxVal[wid] = secondMaxVal; + sharedSecondMinVal[wid] = secondMinVal; + } + + __syncthreads(); + + // reduction within a block + __shared__ float shared_absmax_exp_avg; + __shared__ float shared_absmin_exp_avg; + __shared__ float shared_absmax_exp_avg_sq; + __shared__ float shared_absmin_exp_avg_sq; + firstMaxVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMaxVal[lane] : 0; + firstMinVal = + (threadIdx.x < blockDim.x / warpSize) ? 
sharedFirstMinVal[lane] : 1e9; + secondMaxVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMaxVal[lane] : 0; + secondMinVal = + (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMinVal[lane] : 1e9; + if (wid == 0) { + for (int offset = WARPSIZE / 2; offset > 0; offset /= 2) { + float reduceFirstMaxVal = + __shfl_down_sync(0xFFFFFFFF, firstMaxVal, offset); + float reduceFirstMinVal = + __shfl_down_sync(0xFFFFFFFF, firstMinVal, offset); + float reduceSecondMaxVal = + __shfl_down_sync(0xFFFFFFFF, secondMaxVal, offset); + float reduceSecondMinVal = + __shfl_down_sync(0xFFFFFFFF, secondMinVal, offset); + firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal)); + firstMinVal = fmin(firstMinVal, fabsf(reduceFirstMinVal)); + secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal)); + secondMinVal = fmin(secondMinVal, fabsf(reduceSecondMinVal)); + } + if (lane == 0) { + shared_absmax_exp_avg = firstMaxVal; + shared_absmin_exp_avg = firstMinVal; + shared_absmax_exp_avg_sq = secondMaxVal; + shared_absmin_exp_avg_sq = secondMinVal; + } + } + + __syncthreads(); + + if (idx < total_elements) { + // float fp8MaxVal = fp8_dtype_max<__nv_fp8_e4m3>(exp_avg[idx]); + // scaling factor before expanding + float fp8MaxVal = 448; + + // dynamic exponent quantization part + firstMaxVal = shared_absmax_exp_avg + QUANT_MIN_VAL; + firstMinVal = shared_absmin_exp_avg + QUANT_MIN_VAL; + secondMaxVal = shared_absmax_exp_avg_sq + QUANT_MIN_VAL; + secondMinVal = shared_absmin_exp_avg_sq + QUANT_MIN_VAL; + + // calculate the ratio and make the scale to center + float firstRatio = firstMaxVal / firstMinVal; + float secondRatio = secondMaxVal / secondMinVal; + float firstSqrtMinMax = sqrt(firstMaxVal * firstMinVal); + float secondSqrtMinMax = sqrt(secondMaxVal * secondMinVal); + + // printf("Max %f, Min %f, Origin %f \n", firstMaxVal, firstMinVal, + // float_exp_avg); + + // since we use x^k expander, calculate the optimal expanding factor + float ratioUpperBound = fp8MaxVal * fp8MaxVal / 2; + float firstExp = + floor((log2f(ratioUpperBound) / log2f(firstRatio)) * expand_min) / + expand_min; // expand_min is set to 8 for example, then the firstExp is + // the multiple of 1/8 + float secondExp = + floor((log2f(ratioUpperBound) / log2f(secondRatio)) * expand_min) / + expand_min; + + int sign_exp_avg = 1 - 2 * signbit(float_exp_avg); + float_exp_avg = + sign_exp_avg * powf(fabsf(float_exp_avg) / firstSqrtMinMax, firstExp); + float_exp_avg_sq = powf(float_exp_avg_sq / secondSqrtMinMax, secondExp); + + // correspondingly, change the scaling factor + float new_scale_exp_avg = + powf(firstMaxVal / firstSqrtMinMax, firstExp) / fp8MaxVal; + float new_scale_exp_avg_sq = + powf(secondMaxVal / secondSqrtMinMax, secondExp) / fp8MaxVal; + + // quantize the optimizer states + __nv_fp8_e4m3 exp_avg_new = + static_cast<__nv_fp8_e4m3>(float_exp_avg / new_scale_exp_avg); + __nv_fp8_e4m3 exp_avg_sq_new = + static_cast<__nv_fp8_e4m3>(float_exp_avg_sq / new_scale_exp_avg_sq); + // __half exp_avg_new = static_cast<__half>(float_exp_avg / + // new_scale_exp_avg); + // __half exp_avg_sq_new = static_cast<__half>(float_exp_avg_sq / + // new_scale_exp_avg_sq); + + // printf("idx: %d, float: %f, quantize: %f\n", idx, float_exp_avg, + // (float)exp_avg_new * new_scale_exp_avg); + + // store the output + exp_avg[idx] = exp_avg_new; + exp_avg_sq[idx] = exp_avg_sq_new; + scale_exp_avg[scale_idx] = new_scale_exp_avg; + scale_exp_avg_sq[scale_idx] = new_scale_exp_avg_sq; + expand_exp_avg[scale_idx] = firstExp; + 
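+    // Each 128-element group stores three floats of metadata per moment:
+    // the FP8 scale, the expand exponent k, and sqrt(max*min). The next
+    // optimizer step recovers the value at dequantization time as
+    //   x = sign(q) * (|q| * scale)^(1/k) * sqrt(max*min),
+    // matching the dequantize block at the top of this kernel.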
expand_exp_avg_sq[scale_idx] = secondExp; + sqrtminmax_exp_avg[scale_idx] = firstSqrtMinMax; + sqrtminmax_exp_avg_sq[scale_idx] = secondSqrtMinMax; + } +} + +void FP8_AdamW_expand_cuda(torch::Tensor params, // parameter + torch::Tensor grads, // gradient + torch::Tensor exp_avg, // first order momentum + torch::Tensor scale_exp_avg, + torch::Tensor expand_exp_avg, + torch::Tensor sqrtminmax_exp_avg, + torch::Tensor exp_avg_sq, // second order momentum + torch::Tensor scale_exp_avg_sq, + torch::Tensor expand_exp_avg_sq, + torch::Tensor sqrtminmax_exp_avg_sq, float beta1, + float beta2, float lr, float wd, float eps, int step, + int qgroup_size, + int expand_min) { // other parameters + + // CUDA Blocks + int total_elements = params.numel(); + int total_scale_elements = scale_exp_avg.numel(); + AT_ASSERTM(qgroup_size == 128, + "Only Support 128 per-group quantization currently"); + const int block_dim = 128; // This should equal to the qgroup_size + int grid_dim = (total_elements + qgroup_size - 1) / block_dim; + AT_ASSERTM(grid_dim == scale_exp_avg.numel()); + AT_ASSERTM(grid_dim == scale_exp_avg_sq.numel()); + const dim3 blocks(grid_dim); + + // Execution + AT_DISPATCH_FLOATING_TYPES_AND2( + at::kBFloat16, at::kHalf, params.scalar_type(), "fp8_adamw", ([&] { + fp8_adamw_cuda_expand_kernel<<>>( + params.data_ptr(), grads.data_ptr(), + (__nv_fp8_e4m3*)exp_avg.data_ptr(), + scale_exp_avg.data_ptr(), expand_exp_avg.data_ptr(), + sqrtminmax_exp_avg.data_ptr(), + (__nv_fp8_e4m3*)exp_avg_sq.data_ptr(), + scale_exp_avg_sq.data_ptr(), + expand_exp_avg_sq.data_ptr(), + sqrtminmax_exp_avg_sq.data_ptr(), beta1, beta2, lr, wd, eps, + step, qgroup_size, expand_min, total_elements, + total_scale_elements); + })); +} diff --git a/llava/model/coat/optimizer/kernels/include/fp8_adamw.h b/llava/model/coat/optimizer/kernels/include/fp8_adamw.h new file mode 100644 index 0000000000000000000000000000000000000000..c04970da0dd564ea27e1e6203a38206d9506cac8 --- /dev/null +++ b/llava/model/coat/optimizer/kernels/include/fp8_adamw.h @@ -0,0 +1,14 @@ +#ifndef FP8_ADAMW +#define FP8_ADAMW + +#include + +void FP8_AdamW(torch::Tensor params, // parameter + torch::Tensor grads, // gradient + torch::Tensor exp_avg, // first order momentum + torch::Tensor scale_exp_avg, + torch::Tensor exp_avg_sq, // second order momentum + torch::Tensor scale_exp_avg_sq, float beta1, float beta2, + float lr, float wd, float eps, int step, int qgroup_size); + +#endif // FP8_ADAMW diff --git a/llava/model/coat/optimizer/kernels/include/fp8_adamw_expand.h b/llava/model/coat/optimizer/kernels/include/fp8_adamw_expand.h new file mode 100644 index 0000000000000000000000000000000000000000..e1fb69825933154b1a867efaecd54cbdf100201b --- /dev/null +++ b/llava/model/coat/optimizer/kernels/include/fp8_adamw_expand.h @@ -0,0 +1,18 @@ +#ifndef FP8_ADAMW_CONPAND +#define FP8_ADAMW_CONPAND + +#include + +void FP8_AdamW_expand(torch::Tensor params, // parameter + torch::Tensor grads, // gradient + torch::Tensor exp_avg, // first order momentum + torch::Tensor scale_exp_avg, torch::Tensor expand_exp_avg, + torch::Tensor sqrtminmax_exp_avg, + torch::Tensor exp_avg_sq, // second order momentum + torch::Tensor scale_exp_avg_sq, + torch::Tensor expand_exp_avg_sq, + torch::Tensor sqrtminmax_exp_avg_sq, float beta1, + float beta2, float lr, float wd, float eps, int step, + int qgroup_size, int expand_min); + +#endif // FP8_ADAMW_CONPAND diff --git a/llava/model/coat/optimizer/kernels/setup.py b/llava/model/coat/optimizer/kernels/setup.py new file mode 100644 index 
0000000000000000000000000000000000000000..8453efeaa472f3b094a2c157acbebf4c571eb09d --- /dev/null +++ b/llava/model/coat/optimizer/kernels/setup.py @@ -0,0 +1,56 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name="qoptim_cuda", + ext_modules=[ + CUDAExtension( + name="qoptim_cuda", + sources=[ + "fp8_adamw_cuda.cpp", + "fp8_adamw_cuda_kernel.cu", + "fp8_adamw_expand_cuda.cpp", + "fp8_adamw_expand_cuda_kernel.cu", + "bindings.cpp", + ], + # include_dirs=[ + # 'include' + # ], + extra_compile_args={ + "nvcc": [ + "-O3", + "-std=c++17", + "-gencode=arch=compute_90,code=compute_90", + "-DTORCH_USE_CUDA_DSA", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + ] + }, + ), + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/llava/model/configuration_llava.py b/llava/model/configuration_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..b413354ab70581faf619a632e1934a1a956250d0 --- /dev/null +++ b/llava/model/configuration_llava.py @@ -0,0 +1,131 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Literal, Optional + +from pydantic import BaseModel, Field +from transformers import PretrainedConfig + + +class LlavaConfig(PretrainedConfig): + model_type = "llava" + + def __init__( + self, + llm_cfg=None, + vision_tower_cfg=None, + speech_tower_cfg=None, + sound_tower_cfg=None, + mm_projector_cfg=None, + speech_mm_projector_cfg=None, + sound_mm_projector_cfg=None, + architectures=None, + resume_path=None, + hidden_size=None, + mm_hidden_size=None, + speech_hidden_size=None, + sound_hidden_size=None, + image_aspect_ratio=None, + num_video_frames=None, + fps=None, + mm_vision_select_layer=None, + mm_vision_select_feature=None, + mm_use_im_start_end=False, + mm_use_im_patch_token=False, + mm_projector_lr=None, + speech_mm_projector_lr=None, + sound_mm_projector_lr=None, + vision_tower_lr=None, + speech_tower_lr=None, + sound_tower_lr=None, + vision_resolution=None, + interpolate_mode=None, + s2=None, + dynamic_s2=None, + s2_scales=None, + s2_max_split_size=None, + s2_resize_output_to_scale_idx=0, + min_tiles: Optional[int] = 1, + max_tiles: Optional[int] = 12, + video_max_tiles: Optional[int] = 1, + num_time_tokens=None, + time_token_format=None, + image_encoder: str = '{"_target_": "llava.model.encoders.BasicImageEncoder"}', + video_encoder: str = '{"_target_": "llava.model.encoders.BasicVideoEncoder"}', + speech_encoder: str = '{"_target_": "llava.model.encoders.BasicSpeechEncoder"}', + sound_encoder: str = '{"_target_": "llava.model.encoders.BasicSoundEncoder"}', + **kwargs, + ): + super().__init__() + self.architectures = architectures + self.llm_cfg = llm_cfg + self.vision_tower_cfg = vision_tower_cfg + self.speech_tower_cfg = speech_tower_cfg + self.sound_tower_cfg = sound_tower_cfg + self.mm_projector_cfg = mm_projector_cfg + self.speech_mm_projector_cfg = speech_mm_projector_cfg + self.sound_mm_projector_cfg = sound_mm_projector_cfg + self.resume_path = resume_path + + self.hidden_size = hidden_size + self.mm_hidden_size = mm_hidden_size + self.speech_hidden_size = speech_hidden_size + self.sound_hidden_size = sound_hidden_size + self.image_aspect_ratio = image_aspect_ratio + self.num_video_frames = num_video_frames + self.fps = fps + self.mm_vision_select_layer = mm_vision_select_layer + self.mm_vision_select_feature = mm_vision_select_feature + self.mm_use_im_start_end = mm_use_im_start_end + self.mm_use_im_patch_token = mm_use_im_patch_token + self.mm_projector_lr = mm_projector_lr + self.speech_mm_projector_lr = speech_mm_projector_lr + self.sound_mm_projector_lr = sound_mm_projector_lr + self.vision_tower_lr = vision_tower_lr + self.speech_tower_lr = speech_tower_lr + self.sound_tower_lr = sound_tower_lr + self.vision_resolution = vision_resolution + self.interpolate_mode = interpolate_mode + self.s2 = s2 + self.dynamic_s2 = dynamic_s2 + self.s2_scales = s2_scales + self.s2_max_split_size = s2_max_split_size + self.s2_resize_output_to_scale_idx = s2_resize_output_to_scale_idx + self.min_tiles = min_tiles + self.max_tiles = max_tiles + self.video_max_tiles = video_max_tiles + self.num_time_tokens = num_time_tokens + self.time_token_format = time_token_format + + self.image_encoder = image_encoder + self.video_encoder = video_encoder + self.speech_encoder = speech_encoder + self.sound_encoder = sound_encoder + + +class JsonSchemaResponseFormat(BaseModel): + schema_: str = Field(alias="schema") + + +class ResponseFormat(BaseModel): + type: Literal["text", "json_object", "json_schema"] + json_schema: 
Optional[JsonSchemaResponseFormat] = None diff --git a/llava/model/deprecate_consolidate.py b/llava/model/deprecate_consolidate.py new file mode 100644 index 0000000000000000000000000000000000000000..ebb9619df19e4e8791920b4c22d38763538454b1 --- /dev/null +++ b/llava/model/deprecate_consolidate.py @@ -0,0 +1,52 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Usage: +python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate +""" +import argparse + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llava.model import * +from llava.model.utils import auto_upgrade + + +def consolidate_ckpt(src_path, dst_path): + print("Loading model") + auto_upgrade(src_path) + src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) + src_model.save_pretrained(dst_path) + src_tokenizer.save_pretrained(dst_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--src", type=str, required=True) + parser.add_argument("--dst", type=str, required=True) + + args = parser.parse_args() + + consolidate_ckpt(args.src, args.dst) diff --git a/llava/model/encoders/__init__.py b/llava/model/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e80d23d6693e11f0779e342c9b6384d78b4c79 --- /dev/null +++ b/llava/model/encoders/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
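+# The sub-packages imported below re-export the modality-specific encoders
+# (image, video, speech, sound). These are the classes that LlavaConfig's
+# image_encoder / video_encoder / speech_encoder / sound_encoder defaults
+# reference through their "_target_" dotted paths (see configuration_llava.py).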
+ +from .image import * +from .video import * +from .speech import * +from .sound import * + diff --git a/llava/model/encoders/__pycache__/__init__.cpython-310.pyc b/llava/model/encoders/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51821273a8633724bacf886f82c5d3f6dccef0f2 Binary files /dev/null and b/llava/model/encoders/__pycache__/__init__.cpython-310.pyc differ diff --git a/llava/model/encoders/__pycache__/__init__.cpython-311.pyc b/llava/model/encoders/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91555cc8635606593bafd5cbfc61fff128718d6e Binary files /dev/null and b/llava/model/encoders/__pycache__/__init__.cpython-311.pyc differ diff --git a/llava/model/encoders/__pycache__/base.cpython-310.pyc b/llava/model/encoders/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1a5e6a8761723bbebdd2cec99f1d26496e05ce0 Binary files /dev/null and b/llava/model/encoders/__pycache__/base.cpython-310.pyc differ diff --git a/llava/model/encoders/__pycache__/base.cpython-311.pyc b/llava/model/encoders/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e0467e667be5f975f22ccd0d080fe75c8ea560c Binary files /dev/null and b/llava/model/encoders/__pycache__/base.cpython-311.pyc differ diff --git a/llava/model/encoders/base.py b/llava/model/encoders/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4c9230c954a0f711b063e9b9929cbe0e3d40b55a --- /dev/null +++ b/llava/model/encoders/base.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +from torch import nn + +__all__ = ["BaseEncoder"] + + +class BaseEncoder(nn.Module): + def __init__(self, parent: nn.Module) -> None: + super().__init__() + self._parent = [parent] + + @property + def parent(self) -> nn.Module: + return self._parent[0] diff --git a/llava/model/encoders/image/__init__.py b/llava/model/encoders/image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0df4b41c3ec1283ddeeb6024aa584e6cc3a9b4fb --- /dev/null +++ b/llava/model/encoders/image/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +from .basic import * diff --git a/llava/model/encoders/image/basic.py b/llava/model/encoders/image/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..6428e38f56ffe421edd25bf45f378b25849874f8 --- /dev/null +++ b/llava/model/encoders/image/basic.py @@ -0,0 +1,56 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
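+# BasicImageEncoder stacks the incoming images, encodes them with the parent
+# model's encode_images() (forwarding any "block_sizes" from the config), and
+# optionally wraps each feature sequence with embedded start/end marker tokens
+# (end_tokens defaults to "\n").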
+ +from functools import partial +from typing import Any, Dict, List, Optional + +import torch + +from llava.model.encoders.base import BaseEncoder + +__all__ = ["BasicImageEncoder"] + + +class BasicImageEncoder(BaseEncoder): + def __init__( + self, + parent: torch.nn.Module, + start_tokens: Optional[str] = None, + end_tokens: Optional[str] = "\n", + ) -> None: + super().__init__(parent) + self.start_tokens = start_tokens + self.end_tokens = end_tokens + + def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]: + if tokens is None: + return None + token_ids = self.parent.tokenizer(tokens).input_ids + token_ids = torch.tensor(token_ids, device=self.parent.device) + return self.parent.llm.model.embed_tokens(token_ids) + + def _process_features( + self, + features: torch.Tensor, + start_token_embeds: Optional[torch.Tensor], + end_token_embeds: Optional[torch.Tensor], + ) -> torch.Tensor: + if start_token_embeds is not None: + features = torch.cat([start_token_embeds, features], dim=0) + if end_token_embeds is not None: + features = torch.cat([features, end_token_embeds], dim=0) + return features + + def forward(self, images: List[torch.Tensor], config: Dict[str, Any], frame_times=None) -> List[torch.Tensor]: + images = torch.stack(images, dim=0) + features = self.parent.encode_images(images, block_sizes=config.get("block_sizes")) + + process_features = partial( + self._process_features, + start_token_embeds=self.embed_tokens(self.start_tokens), + end_token_embeds=self.embed_tokens(self.end_tokens), + ) + return [process_features(f) for f in features] diff --git a/llava/model/encoders/sound/__init__.py b/llava/model/encoders/sound/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0df4b41c3ec1283ddeeb6024aa584e6cc3a9b4fb --- /dev/null +++ b/llava/model/encoders/sound/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +from .basic import * diff --git a/llava/model/encoders/sound/basic.py b/llava/model/encoders/sound/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..beb5ba5af1ba40fc9a9cba5b41e37673a6cdfeb2 --- /dev/null +++ b/llava/model/encoders/sound/basic.py @@ -0,0 +1,57 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
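+# BasicSoundEncoder follows the same pattern for audio: it stacks the sounds
+# and their masks, calls the parent's encode_sound(sounds, masks), moves each
+# feature tensor onto the parent device, and wraps it with the optional
+# start/end token embeddings.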
+ +from functools import partial +from typing import Any, Dict, List, Optional + +import torch + +from llava.model.encoders.base import BaseEncoder + +__all__ = ["BasicSoundEncoder"] + + +class BasicSoundEncoder(BaseEncoder): + def __init__( + self, + parent: torch.nn.Module, + start_tokens: Optional[str] = None, + end_tokens: Optional[str] = "\n", + ) -> None: + super().__init__(parent) + self.start_tokens = start_tokens + self.end_tokens = end_tokens + + def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]: + if tokens is None: + return None + token_ids = self.parent.tokenizer(tokens).input_ids + token_ids = torch.tensor(token_ids, device=self.parent.device) + return self.parent.llm.model.embed_tokens(token_ids) + + def _process_features( + self, + features: torch.Tensor, + start_token_embeds: Optional[torch.Tensor], + end_token_embeds: Optional[torch.Tensor], + ) -> torch.Tensor: + features = features.to(self.parent.device) + if start_token_embeds is not None: + features = torch.cat([start_token_embeds, features], dim=0) + if end_token_embeds is not None: + features = torch.cat([features, end_token_embeds], dim=0) + return features + + def forward(self, sounds: List[torch.Tensor], config: Dict[str, Any], masks: Dict[str, Any]) -> List[torch.Tensor]: + sounds = torch.stack(sounds, dim=0) + masks = torch.stack(masks, dim=0) + features = self.parent.encode_sound(sounds, masks) + process_features = partial( + self._process_features, + start_token_embeds=self.embed_tokens(self.start_tokens), + end_token_embeds=self.embed_tokens(self.end_tokens), + ) + return [process_features(f) for f in features] diff --git a/llava/model/encoders/speech/__init__.py b/llava/model/encoders/speech/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0df4b41c3ec1283ddeeb6024aa584e6cc3a9b4fb --- /dev/null +++ b/llava/model/encoders/speech/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +from .basic import * diff --git a/llava/model/encoders/speech/basic.py b/llava/model/encoders/speech/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..68226501c5fcfc8d0cc6030ca71b324a7cb7f360 --- /dev/null +++ b/llava/model/encoders/speech/basic.py @@ -0,0 +1,55 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
+ +from functools import partial +from typing import Any, Dict, List, Optional + +import torch + +from llava.model.encoders.base import BaseEncoder + +__all__ = ["BasicSpeechEncoder"] + + +class BasicSpeechEncoder(BaseEncoder): + def __init__( + self, + parent: torch.nn.Module, + start_tokens: Optional[str] = None, + end_tokens: Optional[str] = "\n", + ) -> None: + super().__init__(parent) + self.start_tokens = start_tokens + self.end_tokens = end_tokens + + def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]: + if tokens is None: + return None + token_ids = self.parent.tokenizer(tokens).input_ids + token_ids = torch.tensor(token_ids, device=self.parent.device) + return self.parent.llm.model.embed_tokens(token_ids) + + def _process_features( + self, + features: torch.Tensor, + start_token_embeds: Optional[torch.Tensor], + end_token_embeds: Optional[torch.Tensor], + ) -> torch.Tensor: + if start_token_embeds is not None: + features = torch.cat([start_token_embeds, features], dim=0) + if end_token_embeds is not None: + features = torch.cat([features, end_token_embeds], dim=0) + return features + + def forward(self, speeches: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]: + speeches = torch.stack(speeches, dim=0) + features = self.parent.encode_speech(speeches) + process_features = partial( + self._process_features, + start_token_embeds=self.embed_tokens(self.start_tokens), + end_token_embeds=self.embed_tokens(self.end_tokens), + ) + return [process_features(f) for f in features] diff --git a/llava/model/encoders/video/__init__.py b/llava/model/encoders/video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b00796ffe81fd8d1cb50cd604a353e8c4e8199ad --- /dev/null +++ b/llava/model/encoders/video/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +from .basic import * +from .tsp import * diff --git a/llava/model/encoders/video/basic.py b/llava/model/encoders/video/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..8b84c6917874686d071ef60a5bd6f2cab0a7f9e4 --- /dev/null +++ b/llava/model/encoders/video/basic.py @@ -0,0 +1,59 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
+ +from functools import partial +from typing import Any, Dict, List, Optional + +import torch + +from llava.model.encoders.base import BaseEncoder + +__all__ = ["BasicVideoEncoder"] + + +class BasicVideoEncoder(BaseEncoder): + def __init__( + self, + parent: torch.nn.Module, + start_tokens: Optional[str] = None, + end_tokens: Optional[str] = "\n", + ) -> None: + super().__init__(parent) + self.start_tokens = start_tokens + self.end_tokens = end_tokens + + def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]: + if tokens is None: + return None + token_ids = self.parent.tokenizer(tokens).input_ids + token_ids = torch.tensor(token_ids, device=self.parent.device) + return self.parent.llm.model.embed_tokens(token_ids) + + def _process_features( + self, + features: torch.Tensor, + start_token_embeds: Optional[torch.Tensor], + end_token_embeds: Optional[torch.Tensor], + ) -> torch.Tensor: + if start_token_embeds is not None: + start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0) + features = torch.cat([start_embeds, features], dim=1) + if end_token_embeds is not None: + end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0) + features = torch.cat([features, end_embeds], dim=1) + return features.flatten(0, 1) + + def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]: + num_frames = [video.shape[0] for video in videos] + images = torch.cat(videos, dim=0) + features = self.parent.encode_images(images) + features = torch.split(features, num_frames) + process_features = partial( + self._process_features, + start_token_embeds=self.embed_tokens(self.start_tokens), + end_token_embeds=self.embed_tokens(self.end_tokens), + ) + return [process_features(f) for f in features] diff --git a/llava/model/encoders/video/tsp.py b/llava/model/encoders/video/tsp.py new file mode 100644 index 0000000000000000000000000000000000000000..efa06530b3aea80a994efb5bd3d5db70f46bd577 --- /dev/null +++ b/llava/model/encoders/video/tsp.py @@ -0,0 +1,70 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
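# In BasicVideoEncoder above, the start/end token embeddings are replicated once
# per frame, concatenated along the per-frame token dimension, and the result is
# flattened into a single (frames * tokens_per_frame) sequence. A small shape
# sketch with dummy tensors; the sizes are illustrative assumptions.
def _video_wrap_shape_sketch():
    import torch

    num_frames, tokens_per_frame, hidden = 4, 9, 8
    features = torch.randn(num_frames, tokens_per_frame, hidden)
    end_token_embeds = torch.randn(1, hidden)   # embedding of end_tokens="\n"

    end_embeds = torch.stack([end_token_embeds] * num_frames, dim=0)  # (4, 1, 8)
    features = torch.cat([features, end_embeds], dim=1)               # (4, 10, 8)
    flat = features.flatten(0, 1)                                     # (40, 8)
    assert flat.shape == (num_frames * (tokens_per_frame + 1), hidden)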
+ +from functools import partial +from typing import Any, Dict, List, Optional, Tuple + +import torch + +from .basic import BasicVideoEncoder + +__all__ = ["TSPVideoEncoder"] + + +def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor: + return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1) + + +class TSPVideoEncoder(BasicVideoEncoder): + def __init__( + self, + parent: torch.nn.Module, + pool_sizes: List[Tuple[int, int, int]], + start_tokens: Optional[str] = None, + end_tokens: Optional[str] = "\n", + sep_tokens: Optional[str] = None, + ) -> None: + super().__init__(parent, start_tokens=start_tokens, end_tokens=end_tokens) + self.pool_sizes = pool_sizes + self.sep_tokens = sep_tokens + + def _process_features( + self, + inputs: torch.Tensor, + start_token_embeds: Optional[torch.Tensor], + end_token_embeds: Optional[torch.Tensor], + sep_token_embeds: Optional[torch.Tensor], + ) -> torch.Tensor: + nt, ns = inputs.shape[:2] + nl = int(ns**0.5) + outputs = [] + for pool_size in self.pool_sizes: + features = inputs.view(nt, nl, nl, -1) + for dim, p in enumerate(pool_size): + features = pool(features, p, dim=dim) + features = features.flatten(1, 2) + features = super()._process_features( + features, + start_token_embeds=start_token_embeds, + end_token_embeds=end_token_embeds, + ) + if sep_token_embeds is not None: + features = torch.cat([features, sep_token_embeds], dim=0) + outputs.append(features) + return torch.cat(outputs, dim=0) + + def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]: + num_frames = [video.shape[0] for video in videos] + images = torch.cat(videos, dim=0) + features = self.parent.encode_images(images) + features = torch.split(features, num_frames) + process_features = partial( + self._process_features, + start_token_embeds=self.embed_tokens(self.start_tokens), + end_token_embeds=self.embed_tokens(self.end_tokens), + sep_token_embeds=self.embed_tokens(self.sep_tokens), + ) + return [process_features(f) for f in features] diff --git a/llava/model/language_model/builder.py b/llava/model/language_model/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..6936800c3c06666fc9d3d02f7424d0eb6409ea39 --- /dev/null +++ b/llava/model/language_model/builder.py @@ -0,0 +1,223 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
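# The pool() helper in llava/model/encoders/video/tsp.py above performs average
# pooling over consecutive groups of `size` elements along `dim`, which
# TSPVideoEncoder applies over the (time, height, width) axes for every entry
# in pool_sizes. A quick self-contained check of that behaviour with a tiny
# illustrative tensor:
def _tsp_pool_sketch():
    import torch

    from llava.model.encoders.video.tsp import pool

    x = torch.arange(8.0).view(4, 2)      # 4 "frames", 2 channels
    pooled = pool(x, size=2, dim=0)       # average consecutive pairs of frames
    expected = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
    assert torch.equal(pooled, expected)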
+# +# SPDX-License-Identifier: Apache-2.0 + +import math +import os +import os.path as osp +import warnings +from dataclasses import asdict +from typing import Tuple + +import torch +from huggingface_hub import file_exists, repo_exists +from huggingface_hub.utils import HFValidationError +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForVision2Seq, + AutoTokenizer, + PretrainedConfig, + PreTrainedModel, + PreTrainedTokenizer, +) + + +from llava.constants import MEDIA_TOKENS +from llava.model.utils import packing +from llava.utils.logging import logger +from llava.utils.tokenizer import infer_stop_tokens + + +def has_tokenizer(repo_id_or_path: str) -> bool: + # Check if the tokenizer is in a local directory + if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")): + return True + + # Check if the tokenizer is in a Hugging Face Hub repo + try: + return repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, "tokenizer_config.json") + except HFValidationError: + return False + + +def context_length_extension(config): + orig_ctx_len = getattr(config, "max_position_embeddings", None) + model_max_length = getattr(config, "model_max_length", None) + if orig_ctx_len and model_max_length > orig_ctx_len: + print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}") + scaling_factor = float(math.ceil(model_max_length / orig_ctx_len)) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + return config + + +def build_llm_and_tokenizer( + model_name_or_path: str, + config: PretrainedConfig, + attn_implementation=None, + model_max_length=None, + *args, + **kwargs, +) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: + # print(model_name_or_path) + llm_cfg = AutoConfig.from_pretrained(model_name_or_path) + llm_cfg._attn_implementation = attn_implementation + llm_cfg.model_max_length = model_max_length + if model_max_length is not None: + context_length_extension(llm_cfg) + + # Quantization related + quantization_restore_from_checkpoint = False + if kwargs.get("quantize_model_class") is not None: + assert kwargs.get("model_args") is not None + quantize_model_class = kwargs.pop("quantize_model_class", None) + model_args = kwargs.pop("model_args", None) + + if quantize_model_class == "QLlamaForCausalLM": # TODO: Also change the name of this class + from .qllama import QLlamaConfig + + llm_cfg.architectures = "QLlamaForCausalLM" + _attn_implementation = llm_cfg._attn_implementation + llm_cfg = QLlamaConfig(**llm_cfg.to_dict()) + llm_cfg._attn_implementation = _attn_implementation + elif quantize_model_class == "QMemLlamaForCausalLM": # TODO: Also change the name of this class + from .qmemllama import QMemLlamaConfig + + llm_cfg.architectures = "QMemLlamaForCausalLM" + llm_cfg = QMemLlamaConfig(**llm_cfg.to_dict()) + elif quantize_model_class == "FP8LinearQwen2ForCausalLM": + from .configuration_quantize import QuantizationConfig + from .fp8linearqwen2 import FP8LinearQwen2Config + + llm_cfg.architectures = "FP8LinearQwen2ForCausalLM" + coat_fp8_args = QuantizationConfig(**asdict(model_args)) + + # Remove the quantization args from llm_cfg and make it a independent config + model_args_dict = asdict(model_args) + for key in asdict(coat_fp8_args).keys(): + model_args_dict.pop(key, None) + + llm_cfg.coat_fp8_args = asdict(coat_fp8_args) + _attn_implementation = llm_cfg._attn_implementation + + llm_cfg = FP8LinearQwen2Config(**llm_cfg.to_dict()) + llm_cfg._attn_implementation = _attn_implementation + + elif quantize_model_class == 
"FP8ActivationQwen2ForCausalLM": + from ..coat.activation.models._fp8_quantization_config import QuantizationConfig + from .fp8activationqwen2 import FP8ActivationQwen2Config + + quantization_restore_from_checkpoint = True + + llm_cfg.architectures = "FP8ActivationQwen2ForCausalLM" + coat_fp8_args = QuantizationConfig(**asdict(model_args)) + + # Remove the quantization args from llm_cfg and make it a independent config + model_args_dict = asdict(model_args) + for key in asdict(coat_fp8_args).keys(): + model_args_dict.pop(key, None) + + llm_cfg.coat_fp8_args = asdict(coat_fp8_args) + _attn_implementation = llm_cfg._attn_implementation + + llm_cfg = FP8ActivationQwen2Config(**llm_cfg.to_dict()) + llm_cfg._attn_implementation = _attn_implementation + + elif quantize_model_class == "FP8ActivationResidualQwen2ForCausalLM": + from ..coat.activation.models._fp8_quantization_config import QuantizationConfig + from .fp8activationresidualqwen2 import FP8ActivationResidualQwen2Config + + quantization_restore_from_checkpoint = True + + llm_cfg.architectures = "FP8ActivationResidualQwen2ForCausalLM" + coat_fp8_args = QuantizationConfig(**asdict(model_args)) + + # Remove the quantization args from llm_cfg and make it a independent config + model_args_dict = asdict(model_args) + for key in asdict(coat_fp8_args).keys(): + model_args_dict.pop(key, None) + + llm_cfg.coat_fp8_args = asdict(coat_fp8_args) + _attn_implementation = llm_cfg._attn_implementation + + llm_cfg = FP8ActivationResidualQwen2Config(**llm_cfg.to_dict()) + llm_cfg._attn_implementation = _attn_implementation + else: + raise ValueError(f"{quantize_model_class} is not supported quantize_model_class.") + + kwargs.pop("quantize_model_class", None) + + if quantize_model_class in [ + "FP8LinearQwen2ForCausalLM", + "FP8ActivationQwen2ForCausalLM", + "FP8ActivationResidualQwen2ForCausalLM", + ]: # Remove the quantization args from llm_cfg and make it a independent config + llm_cfg.update(model_args_dict) + else: + llm_cfg.update(asdict(model_args)) + # print(model_args) + + if quantization_restore_from_checkpoint: + fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None) + + llm = AutoModelForCausalLM.from_pretrained( + model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs + ) + + else: + llm = AutoModelForCausalLM.from_pretrained( + model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs + ) + packing.patch(llm) + + # Locate the tokenizer. + llm_path = model_name_or_path + if not has_tokenizer(llm_path): + llm_path = osp.join(llm_path, "llm") + if not has_tokenizer(llm_path): + raise ValueError(f"Cannot find tokenizer in {llm_path}.") + + tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False) + if model_max_length is not None: + tokenizer.model_max_length = model_max_length + + # Load chat template if specified. 
+ if getattr(config, "chat_template", None) is not None: + logger.info(f"Using chat template: {config.chat_template}") + fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja") + with open(fpath) as fd: + chat_template = fd.read() + tokenizer.chat_template = chat_template.replace(" ", "").replace("\n", "") + + # Set stop tokens for the tokenizer + tokenizer.stop_tokens = infer_stop_tokens(tokenizer) + tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens) + + # Add media tokens to the tokenizer + tokenizer.media_tokens = MEDIA_TOKENS + tokenizer.media_token_ids = {} + for name, token in MEDIA_TOKENS.items(): + tokenizer.add_tokens([token], special_tokens=True) + tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token) + + # TODO(ligeng): is this necessary for llava? + config.hidden_size = llm.config.hidden_size + return llm, tokenizer diff --git a/llava/model/language_model/chat_templates/mistral.jinja b/llava/model/language_model/chat_templates/mistral.jinja new file mode 100644 index 0000000000000000000000000000000000000000..774073e68db8dc9ffc65daab8d948c0079235a5d --- /dev/null +++ b/llava/model/language_model/chat_templates/mistral.jinja @@ -0,0 +1,11 @@ +{{ bos_token }} + +{% for message in messages if message['content'] is not none %} + {% if message['role'] == 'system' %} + {{ message['content'] | trim + '\n\n' }} + {% elif message['role'] == 'user' %} + {{ '[INST] ' + message['content'] | trim + ' [/INST]' }} + {% elif message['role'] == 'assistant' %} + {{ ' ' + message['content'] | trim + eos_token }} + {% endif %} +{% endfor %} diff --git a/llava/model/language_model/chat_templates/qwen2.jinja b/llava/model/language_model/chat_templates/qwen2.jinja new file mode 100644 index 0000000000000000000000000000000000000000..d3b0fd8c09e426c3a8c6e29edde3e56586c961bb --- /dev/null +++ b/llava/model/language_model/chat_templates/qwen2.jinja @@ -0,0 +1,11 @@ +{% if messages[0]['role'] != 'system' %} + {{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }} +{% endif %} + +{% for message in messages if message['content'] is not none %} + {{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }} +{% endfor %} + +{% if add_generation_prompt %} + {{ '<|im_start|>assistant\n' }} +{% endif %} diff --git a/llava/model/language_model/configuration_quantize.py b/llava/model/language_model/configuration_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..9b226923e1ab06e6f868455e72dd8430386c8ac8 --- /dev/null +++ b/llava/model/language_model/configuration_quantize.py @@ -0,0 +1,77 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
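# A minimal, self-contained sketch of what the chat_templates/qwen2.jinja file
# above renders to, using jinja2 directly instead of a real tokenizer. The
# template body is reproduced with its indentation newlines removed, similar in
# spirit to how builder.py flattens the template before assigning it to
# tokenizer.chat_template. The "<sound>" media placeholder in the user message
# is an assumption for illustration (the actual MEDIA_TOKENS strings come from
# llava.constants, which is not part of this diff).
def _qwen2_chat_template_sketch():
    from jinja2 import Environment

    template = (
        "{% if messages[0]['role'] != 'system' %}"
        "{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}"
        "{% endif %}"
        "{% for message in messages if message['content'] is not none %}"
        "{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "{{ '<|im_start|>assistant\n' }}"
        "{% endif %}"
    )
    prompt = Environment().from_string(template).render(
        messages=[{"role": "user", "content": "What can you hear in this clip? <sound>"}],
        add_generation_prompt=True,
    )
    # Renders the usual Qwen2 chat layout, ending with the assistant header so
    # generation can continue from there.
    print(prompt)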
+ +from dataclasses import dataclass + +from transformers import PretrainedConfig + + +@dataclass +class QuantizationConfig: + quantize_model: str = "false" + symm: bool = True + epsilon: float = 1e-10 + fabit: str = "E4M3" + fwbit: str = "E4M3" + bobit: str = "E5M2" + row_blocksize: int = -1 + col_blocksize: int = -1 + qchoice: str = "none" + pad_to_multiple_of: int = 0 + + def __init__( + self, + quantize_model, + symm, + epsilon, + fabit, + fwbit, + bobit, + row_blocksize, + col_blocksize, + qchoice, + pad_to_multiple_of, + **kwargs, + ): + super().__init__() + self.quantize_model = quantize_model + self.symm = symm + self.epsilon = epsilon + self.fabit = fabit + self.fwbit = fwbit + self.bobit = bobit + self.row_blocksize = row_blocksize + self.col_blocksize = col_blocksize + self.qchoice = qchoice + self.pad_to_multiple_of = pad_to_multiple_of + + +# class QuantizationConfig(PretrainedConfig): +# def __init__( +# self, +# quantize_model="false", +# symm=True, +# epsilon=1e-10, +# fabit="E4M3", +# fwbit="E4M3", +# bobit="E5M2", +# row_blocksize=-1, +# col_blocksize=-1, +# qchoice="none", +# pad_to_multiple_of=0, +# **kwargs, +# ): +# super().__init__() +# self.quantize_model = quantize_model +# self.symm = symm +# self.epsilon = epsilon +# self.fabit = fabit +# self.fwbit = fwbit +# self.bobit = bobit +# self.row_blocksize = row_blocksize +# self.col_blocksize = col_blocksize +# self.qchoice = qchoice +# self.pad_to_multiple_of = pad_to_multiple_of diff --git a/llava/model/language_model/fp8_qwen2_convert_from_hf.py b/llava/model/language_model/fp8_qwen2_convert_from_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..71a7f9f602491851c21cf7a7702a7174b414bfee --- /dev/null +++ b/llava/model/language_model/fp8_qwen2_convert_from_hf.py @@ -0,0 +1,73 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +import argparse +import os +from dataclasses import asdict, dataclass, field +from typing import Optional + +import torch +import transformers +from transformers import AutoConfig, AutoModelForCausalLM + +from ..coat.activation.models._fp8_quantization_config import QuantizationConfig +from .fp8activationqwen2 import FP8ActivationQwen2Config, make_state_dict_compatible + + +@dataclass +class ConvertArguments: + model_name: str = field(metadata={"help": "The model name or path to download the LLaMA model"}) + save_path: str = field(metadata={"help": "The path where the converted model weights will be saved"}) + cache_dir: str = field(default=None, metadata={"help": "Directory to cache the model"}) + + +def download_and_convert_qwen2(convert_args: ConvertArguments, quantization_args: QuantizationConfig): + """ + Downloads a LLaMA model, converts its weights using `make_state_dict_compatible`, + and saves the converted model. + + Args: + model_name (str): The model name or path to download the LLaMA model. + save_path (str): The path where the converted model weights will be saved. + cache_dir (Optional[str]): Directory to cache the model. Defaults to None. 
+ + Returns: + None + """ + model_name = convert_args.model_name + save_path = convert_args.save_path + cache_dir = convert_args.cache_dir + + # Step 1: Download the original LLaMA model + model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir) + + # Step 2: Initialize the model configuration for FP8 or other custom config + config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) + + # Step 3: Apply make_state_dict_compatible to convert weights + compatible_state_dict = make_state_dict_compatible(model.state_dict()) + + # Step 4: Create a new model instance with compatible configuration + fp8_config = FP8ActivationQwen2Config(**config.to_dict()) + fp8_config.coat_fp8_args = asdict(quantization_args) + fp8_config._name_or_path = save_path + + converted_model = AutoModelForCausalLM.from_config(fp8_config, torch_dtype=torch.bfloat16) + converted_model.load_state_dict(compatible_state_dict) + + # Step 5: Save the converted model and configuration using save_pretrained + os.makedirs(save_path, exist_ok=True) + converted_model.save_pretrained(save_path) + print(f"Converted model saved at {save_path}") + + +if __name__ == "__main__": + # Parse command-line arguments + parser = transformers.HfArgumentParser((ConvertArguments, QuantizationConfig)) # NOTE: FP8 + convert_args, quantization_args = parser.parse_args_into_dataclasses() + + # Call the function with parsed arguments + download_and_convert_qwen2(convert_args, quantization_args) diff --git a/llava/model/language_model/fp8activationqwen2.py b/llava/model/language_model/fp8activationqwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..34bde452da53042dc55b39ae204e541e1eb1f920 --- /dev/null +++ b/llava/model/language_model/fp8activationqwen2.py @@ -0,0 +1,1722 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
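# A hypothetical programmatic invocation of download_and_convert_qwen2 from
# fp8_qwen2_convert_from_hf.py above. The checkpoint id is an illustrative
# assumption, and QuantizationConfig() relies on whatever defaults the
# _fp8_quantization_config dataclass defines (that file is not part of this
# diff); in practice the script is driven by HfArgumentParser on the command
# line instead.
def _convert_qwen2_to_fp8_sketch():
    from llava.model.coat.activation.models._fp8_quantization_config import (
        QuantizationConfig,
    )
    from llava.model.language_model.fp8_qwen2_convert_from_hf import (
        ConvertArguments,
        download_and_convert_qwen2,
    )

    convert_args = ConvertArguments(
        model_name="Qwen/Qwen2.5-7B-Instruct",  # assumption: any HF Qwen2-family checkpoint
        save_path="./qwen2_fp8_converted",
    )
    download_and_convert_qwen2(convert_args, QuantizationConfig())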
+"""PyTorch Qwen2 model.""" + +import math +import os +from dataclasses import asdict, dataclass, field +from fnmatch import fnmatch +from functools import lru_cache +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2Attention, + Qwen2DecoderLayer, + Qwen2FlashAttention2, + Qwen2ForCausalLM, + Qwen2MLP, + Qwen2Model, + Qwen2PreTrainedModel, + Qwen2RMSNorm, + Qwen2RotaryEmbedding, + Qwen2SdpaAttention, + apply_rotary_pos_emb, + repeat_kv, + rotate_half, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +# FP8 related +from ..coat.activation.models._fp8_quantization_config import QuantizationConfig +from ..coat.activation.models._fp8_weightcache import FP8CacheWeightModule +from ..coat.activation.models._fp8manager import FP8Manager +from ..coat.activation.real_quantization import ( + Coat_quantize_bgn, + Coat_quantize_end, + fp8_add_Ifp_Ifp_Ofp_Og16, + fp8_add_Ifp_Ifp_Ofp_Opt, + fp8_division, + fp8_division_transpose, + fp8_gelu_backward, + fp8_gelu_forward, + fp8_layernorm_noparam_backward, + fp8_layernorm_noparam_forward, + fp8_linear_backward, + fp8_linear_forward, + fp8_mul_backward, + fp8_mul_forward, + fp8_quantize, + fp8_quantize_pertensor, + fp8_quantize_pertensor_transpose, + fp8_rmsnorm_backward, + fp8_rmsnorm_forward, + fp8_silu_backward, + fp8_silu_forward, + fp8_transpose, +) +from ..liger.cross_entropy import LigerForCausalLMLoss +from ..qlinear_te import QLinearTE + +# from .configuration_quantize import QuantizationConfig + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + +logger = logging.get_logger(__name__) + + +class FP8ActivationQwen2Config(Qwen2Config): + model_type = "fp8activation_qwen2" + + def __init__( + self, + coat_fp8_args=None, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + **kwargs, + ): + super().__init__( + vocab_size, + hidden_size, + intermediate_size, + num_hidden_layers, + num_attention_heads, + num_key_value_heads, + hidden_act, + max_position_embeddings, + initializer_range, + rms_norm_eps, + use_cache, + tie_word_embeddings, + rope_theta, + rope_scaling, + use_sliding_window, + sliding_window, + max_window_layers, + attention_dropout, + **kwargs, + ) + + self.coat_fp8_args = coat_fp8_args 
+ + +class FP8ActivationQwen2BeforeAttentionResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 1 * Linear layers + """ + + def __init__(self, config: FP8ActivationQwen2Config, qargs: QuantizationConfig, layer_idx: Optional[int] = None): + super().__init__(config, qargs, layer_idx) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + + def forward(self, re_x, x, s, rmsnorm_weight): + if self.training: + if self.qargs.weight_memory_efficient: + # Prepare + with torch.no_grad(): + if FP8Manager.is_first_microbatch: + # Directly use the corresponding weight + weight1, _, weight1_s = self.prepare_weight( + self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch + ) + weight2, _, weight2_s = self.prepare_weight( + self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch + ) + weight3, _, weight3_s = self.prepare_weight( + self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch + ) + else: + weight1_s = self.prepare_weight(self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch) + weight2_s = self.prepare_weight(self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch) + weight3_s = self.prepare_weight(self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch) + + weight1, weight2, weight3 = None, None, None + return _FP8ActivationQwen2BeforeAttentionResidual.apply( + re_x, + x, + s, + self.q_proj.weight, + weight1, + None, + weight1_s, + self.q_proj.bias, + self.k_proj.weight, + weight2, + None, + weight2_s, + self.k_proj.bias, + self.v_proj.weight, + weight3, + None, + weight3_s, + self.v_proj.bias, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # Prepare + with torch.no_grad(): + weight1, weight1_t, weight1_s = self.prepare_weight( + self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch + ) + weight2, weight2_t, weight2_s = self.prepare_weight( + self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch + ) + weight3, weight3_t, weight3_s = self.prepare_weight( + self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch + ) + return _FP8ActivationQwen2BeforeAttentionResidual.apply( + re_x, + x, + s, + self.q_proj.weight, + weight1, + weight1_t, + weight1_s, + 
self.k_proj.weight, + weight2, + weight2_t, + weight2_s, + self.v_proj.weight, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + raise NotImplementedError("This should be implemented in the future") + return re_x, self.att_proj(self.attn_norm(re_x)) + + +class _FP8ActivationQwen2BeforeAttentionResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, + re_x, + in_x, + in_s, + weight1_origin, + weight1, + weight1_t, + weight1_s, + weight1_bias, + weight2_origin, + weight2, + weight2_t, + weight2_s, + weight2_bias, + weight3_origin, + weight3, + weight3_t, + weight3_s, + weight3_bias, + rmsnorm_weight, + group_size, + fwobits, + layer_id, + config, + qargs, + eps=1e-5, + ): + # for autograd + if fwobits["fabit"] == "E4M3": + # in_x = in_x.to(torch.float8_e4m3fn) + in_x = in_x.view(torch.float8_e4m3fn) + else: + raise ValueError("fabit should be E4M3") + + # LayerNorm + ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward( + in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True + ) + + # Linear Layer QKV Projection + if qargs.weight_memory_efficient: + assert ( + weight1_t is None and weight2_t is None and weight3_t is None + ) # we should not pass W^T to here if weight memory efficient + if weight1 is None: + assert weight2 is None and weight3 is None + weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s) + weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s) + weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s) + + fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size, bias=weight1_bias) # query states + fc2_x = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, False, group_size, bias=weight2_bias) # key states + fc3_x = fp8_linear_forward(ln_x, ln_s, weight3, weight3_s, False, group_size, bias=weight3_bias) # value states + + # ==================== save for backward ==================== + ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s) + if qargs.weight_memory_efficient: + assert weight1_t is None and weight2_t is None and weight3_t is None + ctx.weight = weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s + else: + ctx.weight = weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s + ctx.bias = weight1_bias, weight2_bias, weight3_bias + + ctx.group_size = group_size + ctx.ln_utils = ln_utils + ctx.utils = fwobits, layer_id, config, qargs + + return re_x, fc1_x, fc2_x, fc3_x + + @staticmethod + def backward(ctx, fp_grad, query_g, key_g, value_g): + in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors + weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s = ctx.weight + weight1_bias, weight2_bias, weight3_bias = ctx.bias + + group_size = ctx.group_size + rms_weight, rstd, num_warps = ctx.ln_utils + fwobits, layer_id, config, qargs = ctx.utils + + # ==================== Begin backward ==================== + # Gradient of Bias TODO: make this better + if weight1_bias is not None and weight2_bias is not None and weight3_bias is not None: + att_q_bg = query_g.reshape(-1, query_g.shape[-1]).sum(0) + att_k_bg = key_g.reshape(-1, key_g.shape[-1]).sum(0) + att_v_bg = value_g.reshape(-1, value_g.shape[-1]).sum(0) + else: + att_q_bg = None + att_k_bg = None + att_v_bg = None + + # Quantize the RoPE and FlashAttention Output. 
grad_input and grad_weight requires different data layout. + query_g, query_gs, query_g_t = fp8_quantize_pertensor_transpose( + query_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + key_g, key_gs, key_g_t = fp8_quantize_pertensor_transpose( + key_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + value_g, value_gs, value_g_t = fp8_quantize_pertensor_transpose( + value_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + + # Linear Layer QKV Projection + if qargs.weight_memory_efficient: + weight1_t, weight1_s = fp8_division_transpose( + weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True + ) + weight2_t, weight2_s = fp8_division_transpose( + weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True + ) + weight3_t, weight3_s = fp8_division_transpose( + weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True + ) + + fc1_g1, att_q_wg = fp8_linear_backward( + ln_x_t, ln_s, query_g, query_gs, query_g_t, weight1_t, weight1_s, group_size + ) + fc1_g2, att_k_wg = fp8_linear_backward(ln_x_t, ln_s, key_g, key_gs, key_g_t, weight2_t, weight2_s, group_size) + fc1_g3, att_v_wg = fp8_linear_backward( + ln_x_t, ln_s, value_g, value_gs, value_g_t, weight3_t, weight3_s, group_size + ) + + fc1_g = fc1_g1 + fc1_g2 + fc1_g3 + + # LayerNorm + in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc1_g, rms_weight, rstd, group_size, num_warps) + + # Add the gradient together, and prepare the input of the next layer. + re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt( + fp_grad, in_g, group_size, fwobits["babit"], stochastic=False + ) + + # for autograd. forward's data type should be the same of backward tensor. this will not change the actual binary representation. + in_g = in_g.view(torch.float8_e4m3fn) + + # Although the next operator is a linear layer in MLPResidual module, we return in_sg_g16 to make the size compatible with the forward. Otherwise it will not pass autograd. 
+ return ( + re_g, + in_g, + in_sg_g16, + att_q_wg, + None, + None, + None, + att_q_bg, + att_k_wg, + None, + None, + None, + att_k_bg, + att_v_wg, + None, + None, + None, + att_v_bg, + rms_weight_grad, + None, + None, + None, + None, + None, + None, + ) + + +class FP8ActivationQwen2AfterAttentionResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) 1 * Linear layers + """ + + def __init__(self, config: FP8ActivationQwen2Config, qargs: QuantizationConfig, layer_id): + super().__init__(config, qargs, layer_id) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward(self, re_x, in_x): + if self.training: + if self.qargs.weight_memory_efficient: + # prepare for the weight + with torch.no_grad(): + if FP8Manager.is_first_microbatch: + weight4, _, weight4_s = self.prepare_weight( + self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch + ) + else: + weight4_s = self.prepare_weight(self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch) + weight1 = None + + return _FP8ActivationQwen2AfterAttentionResidual.apply( + re_x, + in_x, + self.o_proj.weight, + weight4, + None, + weight4_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # prepare for the weight + with torch.no_grad(): + weight4, weight4_t, weight4_s = self.prepare_weight( + self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch + ) + + return _FP8ActivationQwen2AfterAttentionResidual.apply( + re_x, + in_x, + self.o_proj.weight, + weight4, + weight4_t, + weight4_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + return re_x + self.attn_out(in_x), None, None + + +class _FP8ActivationQwen2AfterAttentionResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, re_x, flash_x, weight4_origin, weight4, weight4_t, weight4_s, group_size, fwobits, layer_id, config, qargs + ): + time_bench = os.getenv("TIME_BENCH") + + if time_bench: + start_1 = torch.cuda.Event(enable_timing=True) + start_1.record() + + # Quantize the FlashAttention Output + flash_qx, flash_s, _ = fp8_quantize_pertensor( + flash_x, group_size, fwobits["fabit"] + ) # Modified to make it memory efficient + + if time_bench: + end_1 = torch.cuda.Event(enable_timing=True) + end_1.record() + start_2 = torch.cuda.Event(enable_timing=True) + start_2.record() + + # # Attention Projection Linear Layer + if qargs.weight_memory_efficient: + assert weight4_t is None + if weight4 is None: # the second batch + weight4, weight4_s = fp8_division(weight4_origin, qargs.group_size, fwobits["fwbit"], weight4_s) + + if time_bench: + end_2 = torch.cuda.Event(enable_timing=True) + end_2.record() + start_3 = torch.cuda.Event(enable_timing=True) + start_3.record() + + fc4_x = fp8_linear_forward(flash_qx, flash_s, weight4, weight4_s, False, group_size) # + + if time_bench: + end_3 = torch.cuda.Event(enable_timing=True) + end_3.record() + start_4 = torch.cuda.Event(enable_timing=True) + start_4.record() + + # import IPython + # IPython.embed() + # Add the activations together + 
fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc4_x, flash_qx.dtype, group_size) + + if time_bench: + end_4 = torch.cuda.Event(enable_timing=True) + end_4.record() + + torch.cuda.synchronize() + if int(os.environ.get("LOCAL_RANK")) == 0: + print( + f"[AfterAt] Part 1: {start_1.elapsed_time(end_1):.6f} ms | " + f" Part 2: {start_2.elapsed_time(end_2):.6f} ms | " + f" Part 3: {start_3.elapsed_time(end_3):.6f} ms | " + f" Part 4: {start_4.elapsed_time(end_4):.6f} ms | " + f" Input shape: {re_x.shape}" + ) + + # ==================== save for backward ==================== + ctx.save_for_backward(flash_x, flash_s) + if qargs.weight_memory_efficient: + assert weight4_t is None + ctx.weight = weight4_origin, weight4_s + else: + ctx.weight = weight4_t, weight4_s + ctx.group_size = group_size + ctx.fwobits = fwobits + ctx.utils = fwobits, layer_id, config, qargs + + # For autograd + out_x = out_x.view(torch.float8_e4m3fn) + + return fp_x, out_x, out_s + + @staticmethod + def backward(ctx, fp_grad, out_g, out_gs): + flash_x, flash_s = ctx.saved_tensors + weight4_t, weight4_s = ctx.weight + group_size = ctx.group_size + fwobits = ctx.fwobits + fwobits, layer_id, config, qargs = ctx.utils + + # for autograd + if fwobits["babit"] == "E5M2": + # out_g = out_g.to(torch.float8_e5m2) + out_g = out_g.view(torch.float8_e5m2) + else: + raise ValueError("babit should be E5M2") + out_gs_max = out_gs.max() + + # ==================== Begin backward ==================== + # Output Projection + out_g_t = fp8_transpose(out_g, transpose_output_2d=True) + + # We do not save an extra flash_x to save the memory usage + flash_x_t, flash_s = fp8_division_transpose( + flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True + ) + + if qargs.weight_memory_efficient: + weight4_t, weight4_s = fp8_division_transpose( + weight4_t, qargs.group_size, fwobits["fwbit"], weight4_s, only_transposed=True + ) + fc4_g, attn_out_wg = fp8_linear_backward( + flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight4_t, weight4_s, group_size + ) + + return fp_grad, fc4_g, attn_out_wg, None, None, None, None, None, None, None, None + + +class FP8ActivationQwen2MLPResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 2 / 3 * Linear layers + (4) GELU / Silu Activation + """ + + def __init__(self, config: FP8ActivationQwen2Config, qargs: QuantizationConfig, layer_id, hidden_size: int): + super().__init__(config, qargs, layer_id) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.training = True + + # below is only used when training = False + assert config.hidden_act == "silu", "We only support silu activation currently" + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, re_x, x, s, rmsnorm_weight): + if self.training: + if self.qargs.weight_memory_efficient: # prepare for the weight + with torch.no_grad(): + if FP8Manager.is_first_microbatch: + # Directly use the corresponding 
weight + weight1, _, weight1_s = self.prepare_weight( + self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch + ) + weight2, _, weight2_s = self.prepare_weight( + self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch + ) + weight3, _, weight3_s = self.prepare_weight( + self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch + ) + else: + weight1_s = self.prepare_weight( + self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch + ) + weight2_s = self.prepare_weight(self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch) + weight3_s = self.prepare_weight( + self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch + ) + + weight1, weight2, weight3 = None, None, None + return _FP8ActivationQwen2MLPResidual.apply( + re_x, + x, + s, + self.gate_proj.weight, + weight1, + None, + weight1_s, + self.up_proj.weight, + weight2, + None, + weight2_s, + self.down_proj.weight, + weight3, + None, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # prepare for the weight + with torch.no_grad(): + weight1, weight1_t, weight1_s = self.prepare_weight( + self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch + ) + weight2, weight2_t, weight2_s = self.prepare_weight( + self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch + ) + weight3, weight3_t, weight3_s = self.prepare_weight( + self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch + ) + + return _FP8ActivationQwen2MLPResidual.apply( + re_x, + x, + s, + self.gate_proj.weight, + weight1, + weight1_t, + weight1_s, + self.up_proj.weight, + weight2, + weight2_t, + weight2_s, + self.down_proj.weight, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + raise NotImplementedError("Need TODO") + og_x = re_x + re_x = self.ff_norm(re_x) + re_x = self.ff_proj(re_x) + re_x = self.act(re_x) + re_x = self.ff_out(re_x) + re_x = og_x + re_x + return re_x, None, None + + +class _FP8ActivationQwen2MLPResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, + re_x, + in_x, + in_s, + weight1_origin, + weight1, + weight1_t, + weight1_s, + weight2_origin, + weight2, + weight2_t, + weight2_s, + weight3_origin, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + group_size, + fwobits, + layer_id, + config, + qargs, + eps=1e-5, + ): + # For autograd + if fwobits["fabit"] == "E4M3": + # in_x = in_x.to(torch.float8_e4m3fn) + in_x = in_x.view(torch.float8_e4m3fn) + else: + raise ValueError("fabit should be E4M3") + + # LayerNorm + ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward( + in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True + ) + + # Linear Layer of Up Projection and Gate Projection. They are fused as one linear layer. 
+ if qargs.weight_memory_efficient: + assert weight1_t is None and weight2_t is None and weight3_t is None # memory efficient + if weight1 is None: + assert weight2 is None and weight3 is None + weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s) + weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s) + weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s) + + gate_x, gate_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size) # Gate Proj + up_x, up_s = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, True, group_size) # Up Proj + + # silu Activation + silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size) + + # Element-wise Multiplication + mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True) + + # Output Projection + if weight3 is None: # memory efficient + weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s) + fc3_x = fp8_linear_forward(mul_x, mul_s, weight3, weight3_s, False, group_size) + + # Add the activation together + fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc3_x, mul_x.dtype, group_size) + + # ==================== save for backward ==================== + ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s) + + ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s) + if ( + qargs.weight_memory_efficient + ): # Weight_1/2_origin will not be saved twice, so it will be more memory efficient. + assert weight1_t is None and weight2_t is None and weight3_t is None + ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s) + else: # Weight1/2_t is different from the origin weight, so saving it will consumes additional memory footprint. 
+ ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) + + ctx.group_size = group_size + ctx.ln_utils = ln_utils + ctx.utils = fwobits, layer_id, config, qargs + + out_x = out_x.view(torch.float8_e4m3fn) + + return fp_x, out_x, out_s + + @staticmethod + def backward(ctx, fp_grad, out_g, out_gs): + fwobits, layer_id, config, qargs = ctx.utils + + in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s = ctx.saved_tensors + + (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) = ctx.weight + group_size = ctx.group_size + rms_weight, rstd, num_warps = ctx.ln_utils + fwobits, layer_id, config, qargs = ctx.utils + + # For autograd + if fwobits["babit"] == "E5M2": + # out_g = out_g.to(torch.float8_e5m2) + out_g = out_g.view(torch.float8_e5m2) + else: + raise ValueError("babit should be E5M2") + out_gs_max = out_gs.max() + + # ==================== Begin backward ==================== + # Output Projection + out_gs = out_gs.max() + out_g_t = fp8_transpose(out_g, transpose_output_2d=True) + + # Element-wise Multiplication gradient checkpointing + # silu gradient checkpointing + silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size) + mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True) + + if qargs.weight_memory_efficient: + weight3_t, weight3_s = fp8_division_transpose( + weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True + ) + fc3_g, weight3_grad = fp8_linear_backward( + mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight3_t, weight3_s, group_size + ) + + # [MEM TEST] + del out_g, out_g_t, weight3_t + + # Element-wise Multiplication, 1 means gate, 2 means up + mul_g1, (mul_g2, mul_gs2, mul_g2_t) = fp8_mul_backward( + silu_x, silu_s, up_x, up_s, fc3_g, group_size, fwobits["babit"], output_quantized_transpose=True + ) + + # Silu activation + silu_g, silu_gs, silu_g_t = fp8_silu_backward( + gate_x, gate_s, mul_g1, group_size, fwobits["babit"], output_quantized_transpose=True + ) + + # Linear Layer of Up and Gate Projection + if qargs.weight_memory_efficient: + weight1_t, weight1_s = fp8_division_transpose( + weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True + ) + weight2_t, weight2_s = fp8_division_transpose( + weight2_t, group_size, fwobits["fwbit"], weight2_s, only_transposed=True + ) + + # Gate Proj + fc1_g, weight1_grad = fp8_linear_backward( + ln_x_t, ln_s, silu_g, silu_gs, silu_g_t, weight1_t, weight1_s, group_size + ) + fc2_g, weight2_grad = fp8_linear_backward( + ln_x_t, ln_s, mul_g2, mul_gs2, mul_g2_t, weight2_t, weight2_s, group_size + ) + + fc_g = fc1_g + fc2_g + + # layerNorm + in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc_g, rms_weight, rstd, group_size, num_warps) + + # Add the gradient together + re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt( + fp_grad, in_g, group_size, fwobits["babit"], stochastic=False + ) + + in_g = in_g.view(torch.float8_e4m3fn) + + return ( + re_g, + in_g, + in_sg_g16, + weight1_grad, + None, + None, + None, + weight2_grad, + None, + None, + None, + weight3_grad, + None, + None, + None, + rms_weight_grad, + None, + None, + None, + None, + None, + None, + ) + + +class FP8ActivationQwen2AttentionWithoutLinear(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". 
+ """ + + def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.rotary_emb = Qwen2RotaryEmbedding(config=self.config) + + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class FP8ActivationQwen2FlashAttention2WithoutLinear(FP8ActivationQwen2AttentionWithoutLinear): + """ + Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + kv_seq_len = key_states.shape[-2] + cache_position[0] + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. 
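+        # Compute-dtype resolution order used below: the active autocast dtype if autocast is
+        # enabled, then the `_pre_quantization_dtype` recorded on the config, and finally the
+        # q_proj weight dtype as a last resort, before casting q/k/v back for the flash kernel.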
+ input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class FP8ActivationQwen2SdpaAttentionWithoutLinear(FP8ActivationQwen2AttentionWithoutLinear): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
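+        # Example: during prefill (q_len > 1) with no explicit mask, `is_causal=True` lets SDPA build
+        # the causal mask itself; during single-token decoding (q_len == 1) the flag stays False, since
+        # a one-token query attends to the whole cache and needs no causal masking.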
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + return attn_output, None, past_key_value + + +FP8LINEARQWEN2_ATTENTION_CLASSES = { + "eager": FP8ActivationQwen2AttentionWithoutLinear, + "flash_attention_2": FP8ActivationQwen2FlashAttention2WithoutLinear, + "sdpa": FP8ActivationQwen2SdpaAttentionWithoutLinear, +} + + +class FP8ActivationQwen2DecoderLayer(nn.Module): + def __init__(self, config: FP8ActivationQwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = FP8LINEARQWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.qargs = QuantizationConfig(**config.coat_fp8_args) + self.BeforeAttention = FP8ActivationQwen2BeforeAttentionResidual(config, self.qargs, layer_idx) + self.AfterAttention = FP8ActivationQwen2AfterAttentionResidual(config, self.qargs, layer_idx) + self.MLPResidual = FP8ActivationQwen2MLPResidual(config, self.qargs, layer_idx, self.hidden_size) + + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + quant_hidden_states: torch.Tensor, + scale_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
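+            quant_hidden_states (`torch.Tensor`):
+                FP8-quantized copy of `hidden_states` produced by `Coat_quantize_bgn` or by the previous decoder layer.
+            scale_hidden_states (`torch.Tensor`):
+                Quantization scales that accompany `quant_hidden_states`.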
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + time_bench = os.getenv("TIME_BENCH") + + if time_bench: + start_1 = torch.cuda.Event(enable_timing=True) + start_1.record() + + # Coat: The residual, LayerNorm, and the Q/K/V Projection Linear Layer + residual, query_states, key_states, value_states = self.BeforeAttention( + hidden_states, quant_hidden_states, scale_hidden_states, self.input_layernorm.weight + ) + + if time_bench: + end_1 = torch.cuda.Event(enable_timing=True) + end_1.record() + start_2 = torch.cuda.Event(enable_timing=True) + start_2.record() + + # Self Attention without any linear layer + hidden_states, self_attn_weights, present_key_value = self.self_attn( + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + if time_bench: + end_2 = torch.cuda.Event(enable_timing=True) + end_2.record() + start_3 = torch.cuda.Event(enable_timing=True) + start_3.record() + + # Coat: The Output Projection Linear Layer and Residual + hidden_states, quant_hidden_states, scale_hidden_states = self.AfterAttention(residual, hidden_states) + + if time_bench: + end_3 = torch.cuda.Event(enable_timing=True) + end_3.record() + start_4 = torch.cuda.Event(enable_timing=True) + start_4.record() + + # Residual Connection, LayerNorm, and the whole MLP module + hidden_states, quant_hidden_states, scale_hidden_states = self.MLPResidual( + hidden_states, quant_hidden_states, scale_hidden_states, self.post_attention_layernorm.weight + ) + + if time_bench: + end_4 = torch.cuda.Event(enable_timing=True) + end_4.record() + + torch.cuda.synchronize() + if int(os.environ.get("LOCAL_RANK")) == 0: + print( + f"[Forward] Part 1: {start_1.elapsed_time(end_1):.6f} ms | " + f" Part 2: {start_2.elapsed_time(end_2):.6f} ms | " + f" Part 3: {start_3.elapsed_time(end_3):.6f} ms | " + f" Part 4: {start_4.elapsed_time(end_4):.6f} ms | " + f" Input shape: {hidden_states.shape}" + ) + + outputs = ((hidden_states, quant_hidden_states, scale_hidden_states),) + + # if int(os.environ.get("LOCAL_RANK")) == 0: + # import IPython + # IPython.embed() + # else: + # import time + # time.sleep(1000) + # if output_attentions: + # outputs += (self_attn_weights,) + + # if use_cache: + # outputs += (present_key_value,) + + return outputs + + +class FP8ActivationQwen2PreTrainedModel(Qwen2PreTrainedModel): + config_class = FP8ActivationQwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["FP8ActivationQwen2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class FP8ActivationQwen2Model(FP8ActivationQwen2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: FP8ActivationQwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [FP8ActivationQwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + + # Quantize + self.qargs = QuantizationConfig(**config.coat_fp8_args) + self.quantize_input_before_block = Coat_quantize_bgn(self.qargs) + self.quantize_output_after_block = Coat_quantize_end(self.qargs) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + self.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + # import warnings + + # # 将所有 UserWarning 提升为异常 + # warnings.simplefilter("error", UserWarning) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + # Prepare the input for Coat decoderlayer + hidden_states, quant_hidden_states, scale_hidden_states = self.quantize_input_before_block(hidden_states) + + for layer_idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # if int(os.environ.get("LOCAL_RANK")) == 0: + # print(f"FP8 Layer Idx: {layer_idx}, Input Shape: {hidden_states.shape}, Memory Allocated: {torch.cuda.memory_allocated() // 1024 ** 2} MB | Peak: {torch.cuda.max_memory_allocated() // 1024 ** 2} MB") + + # NOTE: We explicitly force the LLM do not use Gradient Checkpointing + # exit(0) + # if self.gradient_checkpointing and self.training: + # layer_outputs = self._gradient_checkpointing_func( + # decoder_layer.__call__, + # hidden_states, + # quant_hidden_states, + # scale_hidden_states, + # causal_mask, + # position_ids, + # past_key_values, + # output_attentions, + # use_cache, + # cache_position, + # position_embeddings, + # ) + # else: + layer_outputs = decoder_layer( + hidden_states, + quant_hidden_states, + scale_hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states, quant_hidden_states, scale_hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # Summarize the output of the Decoder Layer + hidden_states = self.quantize_output_after_block(hidden_states, quant_hidden_states, scale_hidden_states) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + _update_causal_mask = Qwen2Model._update_causal_mask + + +class FP8ActivationQwen2ForCausalLM(FP8ActivationQwen2PreTrainedModel): + _tied_weights_keys = 
["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = FP8ActivationQwen2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @property + @lru_cache + def loss_function(self): + return LigerForCausalLMLoss + + forward = Qwen2ForCausalLM.forward + + +AutoConfig.register("fp8activation_qwen2", FP8ActivationQwen2Config) +AutoModel.register(FP8ActivationQwen2Config, FP8ActivationQwen2Model) +AutoModelForCausalLM.register(FP8ActivationQwen2Config, FP8ActivationQwen2ForCausalLM) + + +def make_state_dict_compatible(state_dict: dict[str, torch.Tensor]): + compatible_state_dict = {} + + for key, value in state_dict.items(): + if fnmatch(key, "*self_attn.q_proj*"): + new_key = key.replace("self_attn.q_proj", "BeforeAttention.q_proj") + elif fnmatch(key, "*self_attn.k_proj*"): + new_key = key.replace("self_attn.k_proj", "BeforeAttention.k_proj") + elif fnmatch(key, "*self_attn.v_proj*"): + new_key = key.replace("self_attn.v_proj", "BeforeAttention.v_proj") + elif fnmatch(key, "*self_attn.o_proj*"): + new_key = key.replace("self_attn.o_proj", "AfterAttention.o_proj") + + elif fnmatch(key, "*mlp.gate_proj*"): + new_key = key.replace("mlp.gate_proj", "MLPResidual.gate_proj") + elif fnmatch(key, "*mlp.up_proj*"): + new_key = key.replace("mlp.up_proj", "MLPResidual.up_proj") + elif fnmatch(key, "*mlp.down_proj*"): + new_key = key.replace("mlp.down_proj", "MLPResidual.down_proj") + + else: + new_key = key + + compatible_state_dict[new_key] = value + + return compatible_state_dict diff --git a/llava/model/language_model/fp8activationresidualqwen2.py b/llava/model/language_model/fp8activationresidualqwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..9bccf84346e0e2b59ed6dc56ca1a6799bc11c5c5 --- /dev/null +++ b/llava/model/language_model/fp8activationresidualqwen2.py @@ -0,0 +1,1586 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Qwen2 model.""" + +import math +import os +from dataclasses import asdict, dataclass, field +from fnmatch import fnmatch +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2Attention, + Qwen2DecoderLayer, + Qwen2FlashAttention2, + Qwen2ForCausalLM, + Qwen2MLP, + Qwen2Model, + Qwen2PreTrainedModel, + Qwen2RMSNorm, + Qwen2RotaryEmbedding, + Qwen2SdpaAttention, + apply_rotary_pos_emb, + repeat_kv, + rotate_half, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +# FP8 related +from ..coat.activation.models._fp8_quantization_config import QuantizationConfig +from ..coat.activation.models._fp8_weightcache import FP8CacheWeightModule +from ..coat.activation.models._fp8manager import FP8Manager +from ..coat.activation.real_quantization import ( + Coat_quantize_bgn, + Coat_quantize_end, + fp8_add_Ifp_Ifp_Ofp_Og16, + fp8_add_Ifp_Ifp_Ofp_Opt, + fp8_division, + fp8_division_transpose, + fp8_gelu_backward, + fp8_gelu_forward, + fp8_layernorm_noparam_backward, + fp8_layernorm_noparam_forward, + fp8_linear_backward, + fp8_linear_forward, + fp8_mul_backward, + fp8_mul_forward, + fp8_quantize, + fp8_quantize_pertensor, + fp8_quantize_pertensor_transpose, + fp8_rmsnorm_backward, + fp8_rmsnorm_forward, + fp8_silu_backward, + fp8_silu_forward, + fp8_transpose, +) +from ..qlinear_te import QLinearTE +from .configuration_quantize import QuantizationConfig + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + +logger = logging.get_logger(__name__) + + +class FP8ActivationResidualQwen2Config(Qwen2Config): + model_type = "fp8activationresidual_qwen2" + + def __init__( + self, + coat_fp8_args=None, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + **kwargs, + ): + super().__init__( + vocab_size, + hidden_size, + intermediate_size, + num_hidden_layers, + num_attention_heads, + num_key_value_heads, + hidden_act, + max_position_embeddings, + initializer_range, + rms_norm_eps, + use_cache, + tie_word_embeddings, + rope_theta, + rope_scaling, + use_sliding_window, + sliding_window, + max_window_layers, + attention_dropout, + **kwargs, + ) + + 
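+        # `coat_fp8_args` is stored as-is and later unpacked into a COAT `QuantizationConfig` by the
+        # FP8 modules in this file. A hypothetical example (keys mirror the QuantizationConfig fields
+        # referenced below, e.g. group_size / fabit / babit / weight_memory_efficient):
+        #
+        #     config = FP8ActivationResidualQwen2Config(
+        #         coat_fp8_args={"group_size": 16, "fabit": "E4M3", "babit": "E5M2",
+        #                        "weight_memory_efficient": True},
+        #     )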
self.coat_fp8_args = coat_fp8_args + + +class FP8ActivationResidualQwen2BeforeAttentionResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 1 * Linear layers + """ + + def __init__( + self, config: FP8ActivationResidualQwen2Config, qargs: QuantizationConfig, layer_idx: Optional[int] = None + ): + super().__init__(config, qargs, layer_idx) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + + def forward(self, x, s, rmsnorm_weight): + if self.training: + if self.qargs.weight_memory_efficient: + # Prepare + with torch.no_grad(): + weight1_s = self.prepare_weight(self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch) + weight2_s = self.prepare_weight(self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch) + weight3_s = self.prepare_weight(self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch) + return _FP8ActivationResidualQwen2BeforeAttentionResidual.apply( + x, + s, + self.q_proj.weight, + None, + None, + weight1_s, + self.q_proj.bias, + self.k_proj.weight, + None, + None, + weight2_s, + self.k_proj.bias, + self.v_proj.weight, + None, + None, + weight3_s, + self.v_proj.bias, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # Prepare + with torch.no_grad(): + weight1, weight1_t, weight1_s = self.prepare_weight( + self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch + ) + weight2, weight2_t, weight2_s = self.prepare_weight( + self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch + ) + weight3, weight3_t, weight3_s = self.prepare_weight( + self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch + ) + return _FP8ActivationResidualQwen2BeforeAttentionResidual.apply( + x, + s, + self.q_proj.weight, + weight1, + weight1_t, + weight1_s, + self.k_proj.weight, + weight2, + weight2_t, + weight2_s, + self.v_proj.weight, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + raise NotImplementedError("This should be implemented in the future") + return re_x, self.att_proj(self.attn_norm(re_x)) + + +class _FP8ActivationResidualQwen2BeforeAttentionResidual(torch.autograd.Function): + 
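+    # Fused autograd function for the pre-attention path: forward runs FP8 RMSNorm on the quantized
+    # residual stream and the Q/K/V projections as FP8 GEMMs; backward re-quantizes the incoming
+    # Q/K/V gradients, produces the weight/bias gradients, and folds the RMSNorm backward together
+    # with the residual gradient into a single input gradient.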
@staticmethod + def forward( + ctx, + in_x, + in_s, + weight1_origin, + weight1, + weight1_t, + weight1_s, + weight1_bias, + weight2_origin, + weight2, + weight2_t, + weight2_s, + weight2_bias, + weight3_origin, + weight3, + weight3_t, + weight3_s, + weight3_bias, + rmsnorm_weight, + group_size, + fwobits, + layer_id, + config, + qargs, + eps=1e-5, + ): + # for autograd + if fwobits["fabit"] == "E4M3": + # in_x = in_x.to(torch.float8_e4m3fn) + in_x = in_x.view(torch.float8_e4m3fn) + else: + raise ValueError("fabit should be E4M3") + + # LayerNorm + ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward( + in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True + ) + + # Linear Layer QKV Projection + if qargs.weight_memory_efficient: + assert weight1 is None # memory efficient + weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s) + weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s) + weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s) + + fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size, bias=weight1_bias) # query states + fc2_x = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, False, group_size, bias=weight2_bias) # key states + fc3_x = fp8_linear_forward(ln_x, ln_s, weight3, weight3_s, False, group_size, bias=weight3_bias) # value states + + # ==================== save for backward ==================== + ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s) + if qargs.weight_memory_efficient: + assert weight1_t is None and weight2_t is None and weight3_t is None + ctx.weight = weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s + else: + ctx.weight = weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s + ctx.bias = weight1_bias, weight2_bias, weight3_bias + + ctx.group_size = group_size + ctx.ln_utils = ln_utils + ctx.utils = fwobits, layer_id, config, qargs + + return in_x, in_s, fc1_x, fc2_x, fc3_x + + @staticmethod + def backward(ctx, q_grad, s_grad, query_g, key_g, value_g): + in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors + weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s = ctx.weight + weight1_bias, weight2_bias, weight3_bias = ctx.bias + + group_size = ctx.group_size + rms_weight, rstd, num_warps = ctx.ln_utils + fwobits, layer_id, config, qargs = ctx.utils + + # ==================== Begin backward ==================== + # Gradient of Bias TODO: make this better + if weight1_bias is not None and weight2_bias is not None and weight3_bias is not None: + att_q_bg = query_g.reshape(-1, query_g.shape[-1]).sum(0) + att_k_bg = key_g.reshape(-1, key_g.shape[-1]).sum(0) + att_v_bg = value_g.reshape(-1, value_g.shape[-1]).sum(0) + else: + att_q_bg = None + att_k_bg = None + att_v_bg = None + + # Quantize the RoPE and FlashAttention Output. grad_input and grad_weight requires different data layout. 
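+        # Each gradient is quantized once but returned in two layouts (row-major plus transposed),
+        # so the grad-input and grad-weight GEMMs below can each consume the orientation they need
+        # without an extra transpose kernel.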
+ query_g, query_gs, query_g_t = fp8_quantize_pertensor_transpose( + query_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + key_g, key_gs, key_g_t = fp8_quantize_pertensor_transpose( + key_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + value_g, value_gs, value_g_t = fp8_quantize_pertensor_transpose( + value_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False + ) + + # Linear Layer QKV Projection + if qargs.weight_memory_efficient: + weight1_t, weight1_s = fp8_division_transpose( + weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True + ) + weight2_t, weight2_s = fp8_division_transpose( + weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True + ) + weight3_t, weight3_s = fp8_division_transpose( + weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True + ) + + fc1_g1, att_q_wg = fp8_linear_backward( + ln_x_t, ln_s, query_g, query_gs, query_g_t, weight1_t, weight1_s, group_size + ) + fc1_g2, att_k_wg = fp8_linear_backward(ln_x_t, ln_s, key_g, key_gs, key_g_t, weight2_t, weight2_s, group_size) + fc1_g3, att_v_wg = fp8_linear_backward( + ln_x_t, ln_s, value_g, value_gs, value_g_t, weight3_t, weight3_s, group_size + ) + + fc1_g = fc1_g1 + fc1_g2 + fc1_g3 + + # LayerNorm + in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc1_g, rms_weight, rstd, group_size, num_warps) + + # Add the gradient together, and prepare the input of the next layer. + in_g, in_sg, in_sg_g16 = fp8_add_Ig16_Ifp_Opt( + q_grad, s_grad, in_g, group_size, fwobits["babit"], stochastic=False + ) + + # for autograd. forward's data type should be the same of backward tensor. this will not change the actual binary representation. + in_g = in_g.view(torch.float8_e4m3fn) + + # Although the next operator is a linear layer in MLPResidual module, we return in_sg_g16 to make the size compatible with the forward. Otherwise it will not pass autograd. 
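+        # autograd expects one gradient per forward input: the residual-stream pair (in_g, in_sg_g16),
+        # then for each of Q/K/V the weight gradient in the original-weight slot (None for the cached
+        # FP8 copies and their scale) followed by the bias gradient, the RMSNorm weight gradient, and
+        # None for the remaining non-tensor arguments.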
+ return ( + in_g, + in_sg_g16, + att_q_wg, + None, + None, + None, + att_q_bg, + att_k_wg, + None, + None, + None, + att_k_bg, + att_v_wg, + None, + None, + None, + att_v_bg, + rms_weight_grad, + None, + None, + None, + None, + None, + None, + ) + + +class FP8ActivationResidualQwen2AfterAttentionResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) 1 * Linear layers + """ + + def __init__(self, config: FP8ActivationResidualQwen2Config, qargs: QuantizationConfig, layer_id): + super().__init__(config, qargs, layer_id) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward(self, re_qx, re_sx, in_x): + if self.training: + if self.qargs.weight_memory_efficient: + # prepare for the weight + with torch.no_grad(): + weight4_s = self.prepare_weight(self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch) + + return _FP8ActivationResidualQwen2AfterAttentionResidual.apply( + re_qx, + re_sx, + in_x, + self.o_proj.weight, + None, + None, + weight4_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # prepare for the weight + with torch.no_grad(): + weight4, weight4_t, weight4_s = self.prepare_weight( + self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch + ) + + return _FP8ActivationResidualQwen2AfterAttentionResidual.apply( + re_qx, + re_sx, + in_x, + self.o_proj.weight, + weight4, + weight4_t, + weight4_s, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + return re_x + self.attn_out(in_x), None, None + + +class _FP8ActivationResidualQwen2AfterAttentionResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, + re_qx, + re_sx, + flash_x, + weight4_origin, + weight4, + weight4_t, + weight4_s, + group_size, + fwobits, + layer_id, + config, + qargs, + ): + # Quantize the FlashAttention Output + flash_qx, flash_s, _ = fp8_quantize_pertensor( + flash_x, group_size, fwobits["fabit"] + ) # Modified to make it memory efficient + + # # Attention Projection Linear Layer + if qargs.weight_memory_efficient: + assert weight4 is None # memory efficient + weight4, weight4_s = fp8_division(weight4_origin, qargs.group_size, fwobits["fwbit"], weight4_s) + fc4_x = fp8_linear_forward(flash_qx, flash_s, weight4, weight4_s, False, group_size) # + + # import IPython + # IPython.embed() + # Add the activations together + fp_x, (out_x, out_s) = fp8_add_Ig16_Ifp_Ofp_Og16(re_qx, re_sx, fc4_x, flash_qx.dtype, group_size) + + # ==================== save for backward ==================== + ctx.save_for_backward(flash_x, flash_s) + if qargs.weight_memory_efficient: + assert weight4_t is None + ctx.weight = weight4_origin, weight4_s + else: + ctx.weight = weight4_t, weight4_s + ctx.group_size = group_size + ctx.fwobits = fwobits + ctx.utils = fwobits, layer_id, config, qargs + + # For autograd + out_x = out_x.view(torch.float8_e4m3fn) + + return fp_x, out_x, out_s + + @staticmethod + def backward(ctx, fp_grad, out_g, out_gs): + flash_x, flash_s = ctx.saved_tensors + weight4_t, weight4_s = ctx.weight + 
group_size = ctx.group_size + fwobits = ctx.fwobits + fwobits, layer_id, config, qargs = ctx.utils + + # for autograd + if fwobits["babit"] == "E5M2": + # out_g = out_g.to(torch.float8_e5m2) + out_g = out_g.view(torch.float8_e5m2) + else: + raise ValueError("babit should be E5M2") + out_gs_max = out_gs.max() + + # ==================== Begin backward ==================== + # Output Projection + out_g_t = fp8_transpose(out_g, transpose_output_2d=True) + + # We do not save an extra flash_x to save the memory usage + flash_x_t, flash_s = fp8_division_transpose( + flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True + ) + + if qargs.weight_memory_efficient: + weight4_t, weight4_s = fp8_division_transpose( + weight4_t, qargs.group_size, fwobits["fwbit"], weight4_s, only_transposed=True + ) + fc4_g, attn_out_wg = fp8_linear_backward( + flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight4_t, weight4_s, group_size + ) + + return fp_grad, fc4_g, attn_out_wg, None, None, None, None, None, None, None, None + + +class FP8ActivationResidualQwen2MLPResidual(FP8CacheWeightModule): + """ + This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 2 / 3 * Linear layers + (4) GELU / Silu Activation + """ + + def __init__(self, config: FP8ActivationResidualQwen2Config, qargs: QuantizationConfig, layer_id, hidden_size: int): + super().__init__(config, qargs, layer_id) + + self.qargs = qargs + self.fwobits = { + "fabit": self.qargs.fabit, + "fwbit": self.qargs.fwbit, + "fobit": self.qargs.fobit, + "babit": self.qargs.babit, + "bwbit": self.qargs.bwbit, + "bobit": self.qargs.bobit, + } + + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.training = True + + # below is only used when training = False + assert config.hidden_act == "silu", "We only support silu activation currently" + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, re_x, x, s, rmsnorm_weight): + if self.training: + if self.qargs.weight_memory_efficient: # prepare for the weight + with torch.no_grad(): + weight1_s = self.prepare_weight(self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch) + weight2_s = self.prepare_weight(self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch) + weight3_s = self.prepare_weight(self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch) + + return _FP8ActivationResidualQwen2MLPResidual.apply( + re_x, + x, + s, + self.gate_proj.weight, + None, + None, + weight1_s, + self.up_proj.weight, + None, + None, + weight2_s, + self.down_proj.weight, + None, + None, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + # prepare for the weight + with torch.no_grad(): + weight1, weight1_t, weight1_s = self.prepare_weight( + self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch + ) + weight2, weight2_t, weight2_s = self.prepare_weight( + self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch + ) + weight3, weight3_t, weight3_s = self.prepare_weight( + self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch + ) + + return _FP8ActivationResidualQwen2MLPResidual.apply( + re_x, + x, + s, + 
self.gate_proj.weight, + weight1, + weight1_t, + weight1_s, + self.up_proj.weight, + weight2, + weight2_t, + weight2_s, + self.down_proj.weight, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + self.qargs.group_size, + self.fwobits, + self.layer_id, + self.config, + self.qargs, + ) + else: + raise NotImplementedError("Need TODO") + og_x = re_x + re_x = self.ff_norm(re_x) + re_x = self.ff_proj(re_x) + re_x = self.act(re_x) + re_x = self.ff_out(re_x) + re_x = og_x + re_x + return re_x, None, None + + +class _FP8ActivationResidualQwen2MLPResidual(torch.autograd.Function): + @staticmethod + def forward( + ctx, + re_x, + in_x, + in_s, + weight1_origin, + weight1, + weight1_t, + weight1_s, + weight2_origin, + weight2, + weight2_t, + weight2_s, + weight3_origin, + weight3, + weight3_t, + weight3_s, + rmsnorm_weight, + group_size, + fwobits, + layer_id, + config, + qargs, + eps=1e-5, + ): + # For autograd + if fwobits["fabit"] == "E4M3": + # in_x = in_x.to(torch.float8_e4m3fn) + in_x = in_x.view(torch.float8_e4m3fn) + else: + raise ValueError("fabit should be E4M3") + + # LayerNorm + ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward( + in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True + ) + + # Linear Layer of Up Projection and Gate Projection. They are fused as one linear layer. + if qargs.weight_memory_efficient: + assert weight1 is None and weight2 is None and weight3 is None # memory efficient + weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s) + weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s) + weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s) + + gate_x, gate_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size) # Gate Proj + up_x, up_s = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, True, group_size) # Up Proj + + # silu Activation + silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size) + + # Element-wise Multiplication + mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True) + + # Output Projection + if weight3 is None: # memory efficient + weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s) + fc3_x = fp8_linear_forward(mul_x, mul_s, weight3, weight3_s, False, group_size) + + # Add the activation together + out_x, out_s = fp8_add_Ifp_Ifp_Og16(re_x, fc3_x, mul_x.dtype, group_size) + + # ==================== save for backward ==================== + ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s) + + ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s) + if ( + qargs.weight_memory_efficient + ): # Weight_1/2_origin will not be saved twice, so it will be more memory efficient. + assert weight1_t is None and weight2_t is None and weight3_t is None + ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s) + else: # Weight1/2_t is different from the origin weight, so saving it will consumes additional memory footprint. 
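+            # Cache the transposed FP8 weights and scales directly; backward can then skip the
+            # re-transposition that the memory-efficient branch performs via `fp8_division_transpose`.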
+ ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) + + ctx.group_size = group_size + ctx.ln_utils = ln_utils + ctx.utils = fwobits, layer_id, config, qargs + + out_x = out_x.view(torch.float8_e4m3fn) + + return out_x, out_s + + @staticmethod + def backward(ctx, out_g, out_gs): + fwobits, layer_id, config, qargs = ctx.utils + + in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s = ctx.saved_tensors + + (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) = ctx.weight + group_size = ctx.group_size + rms_weight, rstd, num_warps = ctx.ln_utils + fwobits, layer_id, config, qargs = ctx.utils + + # For autograd + if fwobits["babit"] == "E5M2": + # out_g = out_g.to(torch.float8_e5m2) + out_g = out_g.view(torch.float8_e5m2) + else: + raise ValueError("babit should be E5M2") + out_gs_max = out_gs.max() + + # ==================== Begin backward ==================== + # Output Projection + out_gs = out_gs.max() + out_g_t = fp8_transpose(out_g, transpose_output_2d=True) + + if qargs.weight_memory_efficient: + weight3_t, weight3_s = fp8_division_transpose( + weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True + ) + fc3_g, weight3_grad = fp8_linear_backward( + mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight3_t, weight3_s, group_size + ) + + # [MEM TEST] + del out_g, out_g_t, weight3_t + + # Element-wise Multiplication, 1 means gate, 2 means up + mul_g1, (mul_g2, mul_gs2, mul_g2_t) = fp8_mul_backward( + silu_x, silu_s, up_x, up_s, fc3_g, group_size, fwobits["babit"], output_quantized_transpose=True + ) + + # Silu activation + silu_g, silu_gs, silu_g_t = fp8_silu_backward( + gate_x, gate_s, mul_g1, group_size, fwobits["babit"], output_quantized_transpose=True + ) + + # Linear Layer of Up and Gate Projection + if qargs.weight_memory_efficient: + weight1_t, weight1_s = fp8_division_transpose( + weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True + ) + weight2_t, weight2_s = fp8_division_transpose( + weight2_t, group_size, fwobits["fwbit"], weight2_s, only_transposed=True + ) + + # Gate Proj + fc1_g, weight1_grad = fp8_linear_backward( + ln_x_t, ln_s, silu_g, silu_gs, silu_g_t, weight1_t, weight1_s, group_size + ) + fc2_g, weight2_grad = fp8_linear_backward( + ln_x_t, ln_s, mul_g2, mul_gs2, mul_g2_t, weight2_t, weight2_s, group_size + ) + + fc_g = fc1_g + fc2_g + + # layerNorm + in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc_g, rms_weight, rstd, group_size, num_warps) + + # Add the gradient together + re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt( + out_g, out_gs_max, in_g, group_size, fwobits["babit"], stochastic=False + ) + + in_g = in_g.view(torch.float8_e4m3fn) + + return ( + re_g, + in_g, + in_sg_g16, + weight1_grad, + None, + None, + None, + weight2_grad, + None, + None, + None, + weight3_grad, + None, + None, + None, + rms_weight_grad, + None, + None, + None, + None, + None, + None, + ) + + +class FP8ActivationResidualQwen2AttentionWithoutLinear(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". 
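+    Unlike the stock Qwen2 attention, this variant holds no linear projections: it receives
+    query/key/value states that were already projected by the surrounding FP8 residual modules.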
+ """ + + def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.rotary_emb = Qwen2RotaryEmbedding(config=self.config) + + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class FP8ActivationResidualQwen2FlashAttention2WithoutLinear(FP8ActivationResidualQwen2AttentionWithoutLinear): + """ + Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def forward(
+        self,
+        query_states: torch.Tensor,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+    ):
+        bsz, q_len, _ = query_states.size()
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            cos, sin = self.rotary_emb(value_states, position_ids)
+        else:
+            cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # Activate slicing cache only if the config has a value `sliding_window` attribute
+            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+            kv_seq_len = key_states.shape[-2] + cache_position[0]
+            if (
+                getattr(self.config, "sliding_window", None) is not None
+                and kv_seq_len > self.config.sliding_window
+                and cache_has_contents
+            ):
+                slicing_tokens = 1 - self.config.sliding_window
+
+                past_key = past_key_value[self.layer_idx][0]
+                past_value = past_key_value[self.layer_idx][1]
+
+                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+                past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+                if past_key.shape[-2] != self.config.sliding_window - 1:
+                    raise ValueError(
+                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+                        f" {past_key.shape}"
+                    )
+
+                if attention_mask is not None:
+                    attention_mask = attention_mask[:, slicing_tokens:]
+                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+        # therefore the input hidden states gets silently casted in float32. Hence, we need
+        # cast them back in float16 just to be sure everything works as expected.
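+        # Flash-attention kernels do not run in float32, so if the activations were silently upcast they are
+        # cast back to a half-precision dtype below (autocast dtype, pre-quantization dtype, or the projection
+        # weight dtype as a last resort).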
+ input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class FP8ActivationResidualQwen2SdpaAttentionWithoutLinear(FP8ActivationResidualQwen2AttentionWithoutLinear): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + def forward( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + query_states=query_states, + key_states=key_states, + value_states=value_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
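+        # For example: prefill with q_len > 1 and no explicit mask -> is_causal=True, so SDPA builds the causal
+        # mask itself; single-token decoding with q_len == 1 -> is_causal=False, since one query needs no mask.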
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + return attn_output, None, past_key_value + + +FP8LINEARRESIDUALQWEN2_ATTENTION_CLASSES = { + "eager": FP8ActivationResidualQwen2AttentionWithoutLinear, + "flash_attention_2": FP8ActivationResidualQwen2FlashAttention2WithoutLinear, + "sdpa": FP8ActivationResidualQwen2SdpaAttentionWithoutLinear, +} + + +class FP8ActivationResidualQwen2DecoderLayer(nn.Module): + def __init__(self, config: FP8ActivationResidualQwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = FP8LINEARRESIDUALQWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.qargs = QuantizationConfig(**config.coat_fp8_args) + self.BeforeAttention = FP8ActivationResidualQwen2BeforeAttentionResidual(config, self.qargs, layer_idx) + self.AfterAttention = FP8ActivationResidualQwen2AfterAttentionResidual(config, self.qargs, layer_idx) + self.MLPResidual = FP8ActivationResidualQwen2MLPResidual(config, self.qargs, layer_idx, self.hidden_size) + + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + quant_hidden_states: torch.Tensor, + scale_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. 
+            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+                with `head_dim` being the embedding dimension of each attention head.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
+        """
+
+        # Coat: The residual, LayerNorm, and the Q/K/V Projection Linear Layer
+        residual_quant, residual_scale, query_states, key_states, value_states = self.BeforeAttention(
+            quant_hidden_states, scale_hidden_states, self.input_layernorm.weight
+        )
+
+        # Self Attention without any linear layer
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            query_states=query_states,
+            key_states=key_states,
+            value_states=value_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+
+        # Coat: The Output Projection Linear Layer and Residual
+        hidden_states, quant_hidden_states, scale_hidden_states = self.AfterAttention(
+            residual_quant, residual_scale, hidden_states
+        )
+
+        # Residual Connection, LayerNorm, and the whole MLP module
+        quant_hidden_states, scale_hidden_states = self.MLPResidual(
+            hidden_states, quant_hidden_states, scale_hidden_states, self.post_attention_layernorm.weight
+        )
+
+        # Return the updated quantized residual (values and scales) so it can be chained into the next layer.
+        outputs = ((quant_hidden_states, scale_hidden_states),)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+class FP8ActivationResidualQwen2PreTrainedModel(Qwen2PreTrainedModel):
+    config_class = FP8ActivationResidualQwen2Config
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["FP8ActivationResidualQwen2DecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_cache_class = True
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+class FP8ActivationResidualQwen2Model(FP8ActivationResidualQwen2PreTrainedModel):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: FP8ActivationResidualQwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [FP8ActivationResidualQwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + + # Quantize + self.qargs = QuantizationConfig(**config.coat_fp8_args) + self.quantize_input_before_block = Coat_quantize_bgn(self.qargs) + self.quantize_output_after_block = Coat_quantize_end(self.qargs) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = self._update_causal_mask(
+            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+        )
+
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = None
+
+        # Prepare the input for Coat decoderlayer
+        quant_hidden_states, scale_hidden_states = self.quantize_input_before_block(hidden_states)
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    quant_hidden_states,
+                    scale_hidden_states,
+                    causal_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
+                    position_embeddings,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    quant_hidden_states,
+                    scale_hidden_states,
+                    attention_mask=causal_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
+                )
+
+            # The first element of `layer_outputs` is the updated (quantized hidden states, scales) pair;
+            # carry it forward so the FP8 residual stream propagates through the layer stack.
+            quant_hidden_states, scale_hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        # Summarize the output of the Decoder Layer
+        hidden_states = self.quantize_output_after_block(quant_hidden_states, scale_hidden_states)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+    _update_causal_mask = Qwen2Model._update_causal_mask
+
+
+class FP8ActivationResidualQwen2ForCausalLM(FP8ActivationResidualQwen2PreTrainedModel):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = FP8ActivationResidualQwen2Model(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def 
get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + forward = Qwen2ForCausalLM.forward + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation + prepare_inputs_for_generation = Qwen2ForCausalLM.prepare_inputs_for_generation + + +AutoConfig.register("fp8activationresidual_qwen2", FP8ActivationResidualQwen2Config) +AutoModel.register(FP8ActivationResidualQwen2Config, FP8ActivationResidualQwen2Model) +AutoModelForCausalLM.register(FP8ActivationResidualQwen2Config, FP8ActivationResidualQwen2ForCausalLM) + + +def make_state_dict_compatible(state_dict: dict[str, torch.Tensor]): + compatible_state_dict = {} + + for key, value in state_dict.items(): + if fnmatch(key, "*self_attn.q_proj*"): + new_key = key.replace("self_attn.q_proj", "BeforeAttention.q_proj") + elif fnmatch(key, "*self_attn.k_proj*"): + new_key = key.replace("self_attn.k_proj", "BeforeAttention.k_proj") + elif fnmatch(key, "*self_attn.v_proj*"): + new_key = key.replace("self_attn.v_proj", "BeforeAttention.v_proj") + elif fnmatch(key, "*self_attn.o_proj*"): + new_key = key.replace("self_attn.o_proj", "AfterAttention.o_proj") + + elif fnmatch(key, "*mlp.gate_proj*"): + new_key = key.replace("mlp.gate_proj", "MLPResidual.gate_proj") + elif fnmatch(key, "*mlp.up_proj*"): + new_key = key.replace("mlp.up_proj", "MLPResidual.up_proj") + elif fnmatch(key, "*mlp.down_proj*"): + new_key = key.replace("mlp.down_proj", "MLPResidual.down_proj") + + else: + new_key = key + + compatible_state_dict[new_key] = value + + return compatible_state_dict diff --git a/llava/model/language_model/fp8linearqwen2.py b/llava/model/language_model/fp8linearqwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..0ae93bb8d6798e9f327198abc405866e4a1527e2 --- /dev/null +++ b/llava/model/language_model/fp8linearqwen2.py @@ -0,0 +1,356 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
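+#
+# This file defines the FP8Linear* Qwen2 variants: only the q/k/v/o and MLP nn.Linear projections are
+# replaced with FP8 QLinearTE layers (configured via `coat_fp8_args`), while the attention computation and
+# the residual stream reuse the stock Qwen2 forward passes in higher precision.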
+"""PyTorch Qwen2 model.""" + +import math +from functools import lru_cache +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2Attention, + Qwen2DecoderLayer, + Qwen2FlashAttention2, + Qwen2ForCausalLM, + Qwen2MLP, + Qwen2Model, + Qwen2PreTrainedModel, + Qwen2RMSNorm, + Qwen2RotaryEmbedding, + Qwen2SdpaAttention, + apply_rotary_pos_emb, + repeat_kv, + rotate_half, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +from ..liger.cross_entropy import LigerForCausalLMLoss +from ..qlinear_te import QLinearTE +from .configuration_quantize import QuantizationConfig + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + +logger = logging.get_logger(__name__) + + +class FP8LinearQwen2Config(Qwen2Config): + model_type = "fp8linear_qwen2" + + def __init__( + self, + coat_fp8_args=None, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + **kwargs, + ): + super().__init__( + vocab_size, + hidden_size, + intermediate_size, + num_hidden_layers, + num_attention_heads, + num_key_value_heads, + hidden_act, + max_position_embeddings, + initializer_range, + rms_norm_eps, + use_cache, + tie_word_embeddings, + rope_theta, + rope_scaling, + use_sliding_window, + sliding_window, + max_window_layers, + attention_dropout, + **kwargs, + ) + + self.coat_fp8_args = coat_fp8_args + + +# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2 +class FP8LinearQwen2MLP(Qwen2MLP): + def __init__(self, config, layer_idx): + super().__init__(config) + # self.gate_proj = te.Linear(self.hidden_size, self.intermediate_size, bias=False) + # self.up_proj = te.Linear(self.hidden_size, self.intermediate_size, bias=False) + # self.down_proj = te.Linear(self.intermediate_size, self.hidden_size, bias=False) + + self.gate_proj = QLinearTE( + self.hidden_size, self.intermediate_size, bias=False, args=config.coat_fp8_args, layer_idx=layer_idx + ) + self.up_proj = QLinearTE( + self.hidden_size, self.intermediate_size, bias=False, args=config.coat_fp8_args, layer_idx=layer_idx + ) + self.down_proj = QLinearTE( + self.intermediate_size, self.hidden_size, bias=False, args=config.coat_fp8_args, layer_idx=layer_idx + ) + + def forward(self, hidden_state): + 
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class FP8LinearQwen2Attention(Qwen2Attention): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: FP8LinearQwen2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + + self.q_proj = QLinearTE( + self.hidden_size, + self.num_heads * self.head_dim, + bias=True, + args=config.coat_fp8_args, + layer_idx=layer_idx, + ) + self.k_proj = QLinearTE( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=True, + args=config.coat_fp8_args, + layer_idx=layer_idx, + ) + self.v_proj = QLinearTE( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=True, + args=config.coat_fp8_args, + layer_idx=layer_idx, + ) + self.o_proj = QLinearTE( + self.num_heads * self.head_dim, + self.hidden_size, + bias=False, + args=config.coat_fp8_args, + layer_idx=layer_idx, + ) + + forward = Qwen2Attention.forward + + +class FP8LinearQwen2FlashAttention2(FP8LinearQwen2Attention): + """ + Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + forward = Qwen2FlashAttention2.forward + + +# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2 +class FP8LinearQwen2SdpaAttention(FP8LinearQwen2Attention): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + forward = Qwen2SdpaAttention.forward + + +FP8LINEARQWEN2_ATTENTION_CLASSES = { + "eager": FP8LinearQwen2Attention, + "flash_attention_2": FP8LinearQwen2FlashAttention2, + "sdpa": FP8LinearQwen2SdpaAttention, +} + + +class FP8LinearQwen2DecoderLayer(nn.Module): + def __init__(self, config: FP8LinearQwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." 
+ ) + self.self_attn = FP8LINEARQWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = FP8LinearQwen2MLP(config, layer_idx) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + forward = Qwen2DecoderLayer.forward + + +class FP8LinearQwen2PreTrainedModel(Qwen2PreTrainedModel): + config_class = FP8LinearQwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["FP8LinearQwen2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class FP8LinearQwen2Model(FP8LinearQwen2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: FP8LinearQwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [FP8LinearQwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + forward = Qwen2Model.forward + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + _update_causal_mask = Qwen2Model._update_causal_mask + + +class FP8LinearQwen2ForCausalLM(FP8LinearQwen2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = FP8LinearQwen2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @property + @lru_cache + def loss_function(self): + return LigerForCausalLMLoss + + forward = Qwen2ForCausalLM.forward + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation + prepare_inputs_for_generation = Qwen2ForCausalLM.prepare_inputs_for_generation + + +AutoConfig.register("fp8linear_qwen2", FP8LinearQwen2Config) +AutoModel.register(FP8LinearQwen2Config, 
FP8LinearQwen2Model) +AutoModelForCausalLM.register(FP8LinearQwen2Config, FP8LinearQwen2ForCausalLM) diff --git a/llava/model/language_model/llava_llama.py b/llava/model/language_model/llava_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..b8fbe3f845abe162edd4d891cca35a752b0966da --- /dev/null +++ b/llava/model/language_model/llava_llama.py @@ -0,0 +1,169 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is modified from https://github.com/haotian-liu/LLaVA/ + + +import os +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import torch +from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast + +from llava.model.loss import soft_cross_entropy +from llava.model.utils.packing import set_seqlens_in_batch +from llava.train.sequence_parallel.globals import get_pg_manager +from llava.utils.logging import logger + +from ...train.utils import calculate_loss_weight +from ..configuration_llava import LlavaConfig +from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel + + +class LlavaLlamaConfig(LlavaConfig): + model_type = "llava_llama" + + +# FIXME we will follow the convention to add a new class for CausalLM in the future +class LlavaLlamaModel(LlavaMetaModel, LlavaMetaForCausalLM, PreTrainedModel): + config_class = LlavaLlamaConfig + main_input_name = "input_embeds" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + + def __init__(self, config: LlavaLlamaConfig = None, *args, **kwargs) -> None: + super().__init__(config) + self.init_vlm(config=config, *args, **kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + use_safetensors: bool = None, + **kwargs, + ): + if hasattr(cls, "load_pretrained"): + return cls.load_pretrained( + pretrained_model_name_or_path, + *model_args, + config=config, + cache_dir=cache_dir, + ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + use_safetensors=use_safetensors, + **kwargs, + ) + return super(LlavaLlamaModel).from_pretrained( + pretrained_model_name_or_path, + *model_args, + config=config, + cache_dir=cache_dir, + ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, + local_files_only=local_files_only, + 
token=token, + revision=revision, + use_safetensors=use_safetensors, + **kwargs, + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + media: Optional[Dict[str, List[torch.Tensor]]] = None, + images: Optional[torch.FloatTensor] = None, + media_config: Optional[List] = None, + attention_mask: Optional[torch.Tensor] = None, + media_meta: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + packing: bool = True, + force_packing: bool = False, + seqlens_in_batch: Optional[torch.LongTensor] = None, + dpo_forward: bool = False, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + self.freezed_module_patch() + + if images is not None: + if media is not None: + raise ValueError("Both 'media' and 'images' are provided. Please provide only one.") + logger.warning("The 'images' argument is deprecated. Please use 'media' instead.") + media = {"image": images} + + if media_config is None: + media_config = defaultdict(dict) + if inputs_embeds is None: + inputs_embeds, labels, attention_mask = self._embed(input_ids, media, media_config, labels, attention_mask,media_meta) + + if force_packing or (packing and self.training and not dpo_forward): + if seqlens_in_batch is None: + seqlens_in_batch = torch.sum(attention_mask, dim=1) + set_seqlens_in_batch(seqlens_in_batch) + + (inputs_embeds, attention_mask, position_ids, labels) = self.repack_multimodal_data( + inputs_embeds, attention_mask, position_ids, labels + ) + + outputs = self.llm( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + labels=labels, + **kwargs, + ) + + if self.training and getattr(self.config, "time_token_ids", []): + outputs.loss = soft_cross_entropy( + outputs.logits, + labels, + soft_tokens=self.config.time_token_ids, + std=self.config.soft_ce_std, + ) + + # Loss rescale for SP + if get_pg_manager() is not None: + loss_weight = calculate_loss_weight(labels) + outputs.loss = outputs.loss * loss_weight + + if dpo_forward: + return outputs.logits, labels + + return outputs + + +AutoConfig.register("llava_llama", LlavaLlamaConfig) +AutoModel.register(LlavaLlamaConfig, LlavaLlamaModel) diff --git a/llava/model/language_model/qllama.py b/llava/model/language_model/qllama.py new file mode 100644 index 0000000000000000000000000000000000000000..db823b94634ec5cfeb589319e947f64a473f3a30 --- /dev/null +++ b/llava/model/language_model/qllama.py @@ -0,0 +1,312 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch LLaMA model.""" +import math +import os +import time +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from flash_attn import flash_attn_func, flash_attn_varlen_func +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers.activations import ACT2FN +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaFlashAttention2, + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaLinearScalingRotaryEmbedding, + LlamaMLP, + LlamaModel, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + LlamaSdpaAttention, + apply_rotary_pos_emb, + repeat_kv, + rotate_half, +) +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +from ..qlinear_te import QLinearTE + +try: + import transformer_engine.pytorch as te +except: + pass +from ..qfunction import * + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "QLlamaConfig" + + +class QLlamaConfig(LlamaConfig): + model_type = "qllama" + + +class QLlamaMLP(LlamaMLP): + def __init__(self, config, layer_idx): + super().__init__(config) + self.layer_idx = layer_idx + + # self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + # self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + # self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.gate_proj = QLinearTE( + self.hidden_size, self.intermediate_size, bias=False, args=config, layer_idx=layer_idx + ) + self.up_proj = QLinearTE(self.hidden_size, self.intermediate_size, bias=False, args=config, layer_idx=layer_idx) + self.down_proj = QLinearTE( + self.intermediate_size, self.hidden_size, bias=False, args=config, layer_idx=layer_idx + ) + + +class QLlamaAttention(LlamaAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: QLlamaConfig, layer_idx): + super().__init__(config) + self.layer_idx = layer_idx + + # self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + # self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + # self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, 
bias=config.attention_bias) + # self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + + self.q_proj = QLinearTE( + self.hidden_size, + self.num_heads * self.head_dim, + bias=config.attention_bias, + args=config, + layer_idx=layer_idx, + ) + self.k_proj = QLinearTE( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + args=config, + layer_idx=layer_idx, + ) + self.v_proj = QLinearTE( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + args=config, + layer_idx=layer_idx, + ) + self.o_proj = QLinearTE( + self.num_heads * self.head_dim, + self.hidden_size, + bias=config.attention_bias, + args=config, + layer_idx=layer_idx, + ) + + +class QLlamaFlashAttention2(QLlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + forward = LlamaFlashAttention2.forward + + +class QLlamaSdpaAttention(QLlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. 
+ """ + + forward = LlamaSdpaAttention.forward + + +QLLAMA_ATTENTION_CLASSES = { + "eager": QLlamaAttention, + "flash_attention_2": QLlamaFlashAttention2, + "sdpa": QLlamaSdpaAttention, +} + + +class QLlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: QLlamaConfig, layer_idx): + super().__init__(config, layer_idx=layer_idx) + self.hidden_size = config.hidden_size + + self.self_attn = QLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = QLlamaMLP(config, layer_idx) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layer_idx = layer_idx + + forward = LlamaDecoderLayer.forward + + +class QLlamaPreTrainedModel(LlamaPreTrainedModel): + config_class = QLlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["QLlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear) or isinstance(module, QLinearTE): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class QLlamaModel(QLlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: QLlamaConfig + """ + + def __init__(self, config: QLlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [QLlamaDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + _update_causal_mask = LlamaModel._update_causal_mask + forward = LlamaModel.forward + + +class QLlamaForCausalLM(QLlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = QLlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.forward_step_id = 0 + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + forward = LlamaForCausalLM.forward + prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation + + +class QLlamaForSequenceClassification(QLlamaPreTrainedModel): + def __init__(self, config): + 
super().__init__(config) + self.num_labels = config.num_labels + self.model = QLlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + forward = LlamaForSequenceClassification.forward + + +AutoConfig.register("qllama", QLlamaConfig) +AutoModel.register(QLlamaConfig, QLlamaModel) +AutoModelForCausalLM.register(QLlamaConfig, QLlamaForCausalLM) diff --git a/llava/model/language_model/qllava_qllama.py b/llava/model/language_model/qllava_qllama.py new file mode 100644 index 0000000000000000000000000000000000000000..2fe4f0785352cc1d9fef4b20afdfb6183672fc32 --- /dev/null +++ b/llava/model/language_model/qllava_qllama.py @@ -0,0 +1,119 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is modified from https://github.com/haotian-liu/LLaVA/ + + +import inspect +import os +import os.path as osp +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from transformers import ( + AutoConfig, + AutoModel, + GenerationConfig, + LlamaConfig, + LlamaForCausalLM, + PretrainedConfig, + PreTrainedModel, +) +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.modeling_utils import ContextManagers, no_init_weights + +from ..configuration_llava import LlavaConfig +from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel +from ..utils import get_model_config, get_model_config_fp8 +from .builder import build_llm_and_tokenizer +from .llava_llama import LlavaLlamaConfig, LlavaLlamaModel + +quantize_args_to_model_class = { + "fp8Linear_llama": "QLlamaForCausalLM", + "fp8LinearAndActivation_llama": "QMemLlamaForCausalLM", + "fp8Linear_qwen2": "FP8LinearQwen2ForCausalLM", + "fp8Activation_qwen2": "FP8ActivationQwen2ForCausalLM", + "fp8ActivationResidual_qwen2": "FP8ActivationResidualQwen2ForCausalLM", +} + + +class QLlavaLlamaConfig(LlavaLlamaConfig): + model_type = "qllava_qllama" + + +## FIXME we will follow the convention to add a new class for CausalLM in the future +class QLlavaLlamaModel(LlavaLlamaModel): + config_class = QLlavaLlamaConfig + main_input_name = "input_embeds" + supports_gradient_checkpointing = True + + def __init__(self, config: QLlavaLlamaConfig = None, model_args=None, *args, **kwargs) -> None: + PreTrainedModel.__init__(self, config) + return self.init_vlm(config=config, model_args=model_args, *args, **kwargs) + + # rewrite to support QLlama + def init_vlm(self, config: PreTrainedModel = None, model_args=None, *args, **kwargs): + # TODO(ligeng): figure out how from_config and from_pretrained works in HF implementation. 
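+        # This override of init_vlm only (re)builds the quantized language model and its tokenizer; for the
+        # FP8 activation variants an additional fp8_llm_cfg is extracted from the config and forwarded to
+        # build_llm_and_tokenizer below.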
+ if hasattr(self, "llm") or hasattr(self, "vision_tower") or hasattr(self, "mm_projector"): + # already initialized, skipped + return + + model_dtype = getattr(config, "model_dtype", "torch.float16") + if not hasattr(config, "model_dtype"): + warnings.warn("model_dtype not found in config, defaulting to torch.float16.") + config.model_dtype = model_dtype + + if model_args.quantize_model in ["fp8Activation_qwen2", "fp8ActivationResidual_qwen2"]: + cfgs = get_model_config_fp8(config) # The first cfg is fp8 + else: + cfgs = get_model_config(config) + if len(cfgs) == 3: + llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs + elif len(cfgs) == 4: + llm_cfg, vision_tower_cfg, mm_projector_cfg, fp8_llm_cfg = cfgs + kwargs.update({"fp8_llm_cfg": fp8_llm_cfg}) + else: + raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.") + + kwargs.update( + { + "quantize_model_class": quantize_args_to_model_class[model_args.quantize_model], + "model_args": model_args, + } + ) + self.llm, self.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs) + + + for name, module in self.llm.named_modules(): + module.layer_name = name + + self.pad_to_multiple_of = model_args.pad_to_multiple_of + + self.post_config() + self.is_loaded = True + + assert ( + self.llm is not None or self.vision_tower is not None or self.mm_projector is not None + ), "At least one of the components must be instantiated." + + +AutoConfig.register("qllava_qllama", QLlavaLlamaConfig) +AutoModel.register(QLlavaLlamaConfig, QLlavaLlamaModel) diff --git a/llava/model/language_model/qmemllama.py b/llava/model/language_model/qmemllama.py new file mode 100644 index 0000000000000000000000000000000000000000..26e794840df2230013a66f643745e9f50ffbeef9 --- /dev/null +++ b/llava/model/language_model/qmemllama.py @@ -0,0 +1,679 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
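+#
+# This file defines the QMemLlama variants, which quantize both the linear layers (QLinear) and the
+# surrounding activations and norms (QAct_FPin/QAct_FPout, QLlamaRMSNorm, QMul), so activations can be
+# carried in quantized form between ops.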
+""" PyTorch LLaMA model.""" +import math +import os +import time +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from flash_attn import flash_attn_func, flash_attn_varlen_func +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaFlashAttention2, + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaLinearScalingRotaryEmbedding, + LlamaMLP, + LlamaModel, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + LlamaSdpaAttention, + apply_rotary_pos_emb, + repeat_kv, + rotate_half, +) +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) + +from ..qlinear_te import QLinearTE + +try: + import transformer_engine.pytorch as te +except: + pass + +from ..quantization import QGELU, QAct_FPin, QAct_FPout, QAdd, QIdentity, QLayerNorm, QLinear, QMul + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "QMemLlamaConfig" + + +class QMemLlamaConfig(LlamaConfig): + model_type = "qmemllama" + + +class QLlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6, args=None, layer_type=None): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.qargs = args + self.QAct_layernorm_in = QAct_FPout(args, layer_type=layer_type + "_in") + self.QAct_layernorm_out = QAct_FPin(args, layer_type=layer_type + "_out") + + def forward(self, hidden_states, s): + hidden_states = self.QAct_layernorm_in(hidden_states, s) + + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = self.weight * hidden_states.to(input_dtype) + + hidden_states, s = self.QAct_layernorm_out(hidden_states) + return hidden_states, s + + +ALL_LAYERNORM_LAYERS.append(QLlamaRMSNorm) + + +class QMemLlamaMLP(LlamaMLP): + def __init__(self, config, layer_idx): + super().__init__(config) + self.layer_idx = layer_idx + + self.gate_proj = QLinear( + self.hidden_size, self.intermediate_size, bias=False, args=config, layer_type="mlp_gate" + ) + self.up_proj = QLinear(self.hidden_size, self.intermediate_size, bias=False, args=config, layer_type="mlp_up") + self.down_proj = QLinear( + self.intermediate_size, self.hidden_size, bias=False, args=config, layer_type="mlp_down" + ) + self.act_fn = ACT2FN[config.hidden_act] + + self.QAct_act_sum = QAct_FPout(config, layer_type="mlp_act_sum") + self.QAct_act_gate = 
QAct_FPin(config, layer_type="mlp_act_gate") + + self.QAct_act_up = QAct_FPin(config, layer_type="mlp_act_up") + + self.QAct_act_in = QAct_FPout(config, layer_type="mlp_act_in") + self.QAct_act_out = QAct_FPin(config, layer_type="mlp_act_out") + + self.QMul_act = QMul(config, layer_type="mul_act") + + def forward(self, x, s): + if self.config.pretraining_tp > 1: + raise ValueError("Currently Quantization is not implemented for tensor parallel for simplicity") + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + # down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + x = self.QAct_act_sum(x, s) + x_gate, s_gate = self.QAct_act_gate(x) + x_up, s_up = self.QAct_act_up(x) + x_gate, s_gate = self.gate_proj(x_gate, s_gate) + x_gate = self.QAct_act_in(x_gate, s_gate) + x_gate = self.act_fn(x_gate) + x_gate, s_gate = self.QAct_act_out(x_gate) + + x_up, s_up = self.up_proj(x_up, s_up) + x, s = self.QMul_act(x_gate, x_up, s_gate, s_up) + down_proj, s = self.down_proj(x, s) + return down_proj, s + + +class QMemLlamaAttention(LlamaAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: QMemLlamaConfig, layer_idx): + super().__init__(config) + self.layer_idx = layer_idx + + self.q_proj = QLinear( + self.hidden_size, + self.num_heads * self.head_dim, + bias=config.attention_bias, + args=config, + layer_type="attn_q", + ) + self.k_proj = QLinear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + args=config, + layer_type="attn_k", + ) + self.v_proj = QLinear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + args=config, + layer_type="attn_v", + ) + self.o_proj = QLinear( + self.num_heads * self.head_dim, + self.hidden_size, + bias=config.attention_bias, + args=config, + layer_type="attn_proj", + ) + + self.QAct_qkv_sum = QAct_FPout(config, layer_type="attn_qkv_sum") + + self.QAct_q_in = QAct_FPin(config, layer_type="attn_q_in") + self.QAct_k_in = QAct_FPin(config, layer_type="attn_k_in") + self.QAct_v_in = QAct_FPin(config, layer_type="attn_v_in") + + self.QAct_q_out = QAct_FPout(config, layer_type="attn_q_out") + self.QAct_k_out = QAct_FPout(config, layer_type="attn_k_out") + self.QAct_v_out = QAct_FPout(config, layer_type="attn_v_out") + self.QAct_proj_in = QAct_FPin(config, layer_type="attn_proj_in") + + +class QMemLlamaFlashAttention2(QMemLlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def forward( + self, + hidden_states: torch.Tensor, + s: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + hidden_states = self.QAct_qkv_sum(hidden_states, s) + + q, sq = self.QAct_q_in(hidden_states) + k, sk = self.QAct_k_in(hidden_states) + v, sv = self.QAct_v_in(hidden_states) + + query_states, sq = self.q_proj(q, sq) + key_states, sk = self.k_proj(k, sk) + value_states, sv = self.v_proj(v, sv) + + query_states = self.QAct_q_out(query_states, sq) + key_states = self.QAct_k_out(key_states, sk) + value_states = self.QAct_v_out(value_states, sv) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + + attn_output = attn_output.to(torch.float32) + attn_output, s = self.QAct_proj_in(attn_output) + attn_output, s = self.o_proj(attn_output, s) + + if not output_attentions: + attn_weights = None + + return attn_output, s, attn_weights, past_key_value + + +class QMemLlamaSdpaAttention(QMemLlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + s: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + hidden_states = self.QAct_qkv_sum(hidden_states, s) + + q, sq = self.QAct_q_in(hidden_states) + k, sk = self.QAct_k_in(hidden_states) + v, sv = self.QAct_v_in(hidden_states) + + query_states, sq = self.q_proj(q, sq) + key_states, sk = self.k_proj(k, sk) + value_states, sv = self.v_proj(v, sv) + + query_states = self.QAct_q_out(query_states, sq) + key_states = self.QAct_k_out(key_states, sk) + value_states = self.QAct_v_out(value_states, sv) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
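+        # Descriptive note: is_causal=True lets SDPA apply causal masking internally; it is
+        # only wanted when no explicit attn_mask is provided and q_len > 1 (prefill), since a
+        # single decoded token may attend to every cached position anyway.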
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + # attn_output = attn_output.to(torch.float32) + attn_output, s = self.QAct_proj_in(attn_output) + attn_output, s = self.o_proj(attn_output, s) + + return attn_output, s, None, past_key_value + + +QMemLLAMA_ATTENTION_CLASSES = { + "eager": QMemLlamaAttention, + "flash_attention_2": QMemLlamaFlashAttention2, + "sdpa": QMemLlamaSdpaAttention, +} + + +class QMemLlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: QMemLlamaConfig, layer_idx): + super().__init__(config, layer_idx=layer_idx) + self.hidden_size = config.hidden_size + + self.self_attn = QMemLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = QMemLlamaMLP(config, layer_idx) + + self.input_layernorm = QLlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, args=config, layer_type="ln_attn" + ) + self.post_attention_layernorm = QLlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, args=config, layer_type="ln_mlp" + ) + + self.QAdd_attn = QAdd(config, layer_type="add_attn") + self.QAdd_mlp = QAdd(config, layer_type="add_mlp") + + self.QAct_reattnout_fx = QAct_FPin(config, layer_type="re_attn_out_fx") + self.QAct_reattnout_re = QAct_FPin(config, layer_type="re_attn_out_re") + + self.QAct_remlpout_fx = QAct_FPin(config, layer_type="re_mlp_out_fx") + self.QAct_remlpout_re = QAct_FPin(config, layer_type="re_mlp_out_re") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual, res = self.QAct_reattnout_re(hidden_states) + hidden_states, s = self.QAct_reattnout_fx(hidden_states) + + hidden_states, s = self.input_layernorm(hidden_states, s) + + # Self Attention + hidden_states, s, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + s=s, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = self.QAdd_attn(residual, hidden_states, res, s) + + # Fully Connected + residual, res = self.QAct_remlpout_re(hidden_states) + hidden_states, s = self.QAct_remlpout_fx(hidden_states) + + hidden_states, s = self.post_attention_layernorm(hidden_states, s) + hidden_states, s = self.mlp(hidden_states, s) + hidden_states = self.QAdd_mlp(residual, hidden_states, res, s) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class QMemLlamaPreTrainedModel(LlamaPreTrainedModel): + config_class = QMemLlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["QMemLlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear) or isinstance(module, QLinearTE): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class QMemLlamaModel(QMemLlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`] + + Args: + config: QMemLlamaConfig + """ + + def __init__(self, config: QMemLlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [QMemLlamaDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + _update_causal_mask = LlamaModel._update_causal_mask + forward = LlamaModel.forward + + +class QMemLlamaForCausalLM(QMemLlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = QMemLlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.forward_step_id = 0 + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + forward = LlamaForCausalLM.forward + prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation + + +class QMemLlamaForSequenceClassification(QMemLlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = QMemLlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + forward = LlamaForSequenceClassification.forward + + +AutoConfig.register("qmemllama", QMemLlamaConfig) +AutoModel.register(QMemLlamaConfig, QMemLlamaModel) +AutoModelForCausalLM.register(QMemLlamaConfig, QMemLlamaForCausalLM) diff --git a/llava/model/language_model/realqmemllama.py b/llava/model/language_model/realqmemllama.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb71d015f8f7a0dcffc466b4093f67970537335 --- /dev/null +++ b/llava/model/language_model/realqmemllama.py @@ -0,0 +1,674 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch LLaMA model.""" +import math +import os +import time +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from flash_attn import flash_attn_func, flash_attn_varlen_func +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaFlashAttention2, + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaLinearScalingRotaryEmbedding, + LlamaMLP, + LlamaModel, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + LlamaSdpaAttention, + apply_rotary_pos_emb, + repeat_kv, + rotate_half, +) +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) + +from ..qlinear_te import QLinearTE + +try: + import transformer_engine.pytorch as te +except: + pass +from ..quantization import QGELU, QAct_FPin, QAct_FPout, QAdd, QIdentity, QLayerNorm, QLinear, QMul + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "QMemLlamaConfig" + + +class QMemLlamaConfig(LlamaConfig): + model_type = "qmemllama" + + +class QLlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6, args=None, layer_type=None): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.qargs = args + self.QAct_layernorm_in = QAct_FPout(args, layer_type=layer_type + "_in") + self.QAct_layernorm_out = QAct_FPin(args, layer_type=layer_type + "_out") + + def forward(self, hidden_states, s): + hidden_states = self.QAct_layernorm_in(hidden_states, s) + + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = self.weight * hidden_states.to(input_dtype) + + hidden_states, s = self.QAct_layernorm_out(hidden_states) + return hidden_states, s + + +ALL_LAYERNORM_LAYERS.append(QLlamaRMSNorm) + + +class QMemLlamaMLP(LlamaMLP): + def __init__(self, config, layer_idx): + 
super().__init__(config) + self.layer_idx = layer_idx + self.gate_proj = QLinear( + self.hidden_size, self.intermediate_size, bias=False, args=config, layer_type="mlp_gate" + ) + self.up_proj = QLinear(self.hidden_size, self.intermediate_size, bias=False, args=config, layer_type="mlp_up") + self.down_proj = QLinear( + self.intermediate_size, self.hidden_size, bias=False, args=config, layer_type="mlp_down" + ) + self.act_fn = ACT2FN[config.hidden_act] + + self.QAct_act_sum = QAct_FPout(config, layer_type="mlp_act_sum") + self.QAct_act_gate = QAct_FPin(config, layer_type="mlp_act_gate") + self.QAct_act_up = QAct_FPin(config, layer_type="mlp_act_up") + + self.QAct_act_in = QAct_FPout(config, layer_type="mlp_act_in") + self.QAct_act_out = QAct_FPin(config, layer_type="mlp_act_out") + + self.QMul_act = QMul(config, layer_type="mul_act") + + def forward(self, x, s): + if self.config.pretraining_tp > 1: + raise ValueError("Currently Quantization is not implemented for tensor parallel for simplicity") + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + # down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + x = self.QAct_act_sum(x, s) + x_gate, s_gate = self.QAct_act_gate(x) + x_up, s_up = self.QAct_act_up(x) + x_gate, s_gate = self.gate_proj(x_gate, s_gate) + x_gate = self.QAct_act_in(x_gate, s_gate) + x_gate = self.act_fn(x_gate) + x_gate, s_gate = self.QAct_act_out(x_gate) + + x_up, s_up = self.up_proj(x_up, s_up) + x, s = self.QMul_act(x_gate, x_up, s_gate, s_up) + down_proj, s = self.down_proj(x, s) + return down_proj, s + + +class QMemLlamaAttention(LlamaAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: QMemLlamaConfig, layer_idx): + super().__init__(config) + self.layer_idx = layer_idx + self.q_proj = QLinear( + self.hidden_size, + self.num_heads * self.head_dim, + bias=config.attention_bias, + args=config, + layer_type="attn_q", + ) + self.k_proj = QLinear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + args=config, + layer_type="attn_k", + ) + self.v_proj = QLinear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + args=config, + layer_type="attn_v", + ) + self.o_proj = QLinear( + self.num_heads * self.head_dim, + self.hidden_size, + bias=config.attention_bias, + args=config, + layer_type="attn_proj", + ) + + self.QAct_qkv_sum = QAct_FPout(config, layer_type="attn_qkv_sum") + + self.QAct_q_in = QAct_FPin(config, layer_type="attn_q_in") + self.QAct_k_in = QAct_FPin(config, layer_type="attn_k_in") + self.QAct_v_in = QAct_FPin(config, layer_type="attn_v_in") + + self.QAct_q_out = QAct_FPout(config, layer_type="attn_q_out") + self.QAct_k_out = QAct_FPout(config, layer_type="attn_k_out") + self.QAct_v_out = QAct_FPout(config, layer_type="attn_v_out") + self.QAct_proj_in = 
QAct_FPin(config, layer_type="attn_proj_in") + + +class QMemLlamaFlashAttention2(QMemLlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states: torch.Tensor, + s: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + hidden_states = self.QAct_qkv_sum(hidden_states, s) + + q, sq = self.QAct_q_in(hidden_states) + k, sk = self.QAct_k_in(hidden_states) + v, sv = self.QAct_v_in(hidden_states) + + query_states, sq = self.q_proj(q, sq) + key_states, sk = self.k_proj(k, sk) + value_states, sv = self.v_proj(v, sv) + + query_states = self.QAct_q_out(query_states, sq) + key_states = self.QAct_k_out(key_states, sk) + value_states = self.QAct_v_out(value_states, sv) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. 
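+        # Descriptive note: q/k/v are currently laid out as [bsz, num_heads, q_len, head_dim];
+        # the transposes below restore the [bsz, q_len, num_heads, head_dim] layout that
+        # _flash_attention_forward expects.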
+ query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + + attn_output = attn_output.to(torch.float32) + attn_output, s = self.QAct_proj_in(attn_output) + attn_output, s = self.o_proj(attn_output, s) + + if not output_attentions: + attn_weights = None + + return attn_output, s, attn_weights, past_key_value + + +class QMemLlamaSdpaAttention(QMemLlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + s: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + hidden_states = self.QAct_qkv_sum(hidden_states, s) + + q, sq = self.QAct_q_in(hidden_states) + k, sk = self.QAct_k_in(hidden_states) + v, sv = self.QAct_v_in(hidden_states) + + query_states, sq = self.q_proj(q, sq) + key_states, sk = self.k_proj(k, sk) + value_states, sv = self.v_proj(v, sv) + + query_states = self.QAct_q_out(query_states, sq) + key_states = self.QAct_k_out(key_states, sk) + value_states = self.QAct_v_out(value_states, sv) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + # attn_output = attn_output.to(torch.float32) + attn_output, s = self.QAct_proj_in(attn_output) + attn_output, s = self.o_proj(attn_output, s) + + return attn_output, s, None, past_key_value + + +QMemLLAMA_ATTENTION_CLASSES = { + "eager": QMemLlamaAttention, + "flash_attention_2": QMemLlamaFlashAttention2, + "sdpa": QMemLlamaSdpaAttention, +} + + +class QMemLlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: QMemLlamaConfig, layer_idx): + super().__init__(config, layer_idx=layer_idx) + self.hidden_size = config.hidden_size + + self.self_attn = QMemLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = QMemLlamaMLP(config, layer_idx) + self.input_layernorm = QLlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, args=config, layer_type="ln_attn" + ) + self.post_attention_layernorm = QLlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, args=config, layer_type="ln_mlp" + ) + + self.QAdd_attn = QAdd(config, layer_type="add_attn") + self.QAdd_mlp = QAdd(config, layer_type="add_mlp") + + self.QAct_reattnout_fx = QAct_FPin(config, layer_type="re_attn_out_fx") + self.QAct_reattnout_re = QAct_FPin(config, layer_type="re_attn_out_re") + + self.QAct_remlpout_fx = QAct_FPin(config, layer_type="re_mlp_out_fx") + self.QAct_remlpout_re = QAct_FPin(config, layer_type="re_mlp_out_re") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual, res = self.QAct_reattnout_re(hidden_states) + hidden_states, s = self.QAct_reattnout_fx(hidden_states) + + hidden_states, s = self.input_layernorm(hidden_states, s) + + # Self Attention + hidden_states, s, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + s=s, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = self.QAdd_attn(residual, hidden_states, res, s) + + # Fully Connected + residual, res = self.QAct_remlpout_re(hidden_states) + hidden_states, s = self.QAct_remlpout_fx(hidden_states) + + hidden_states, s = self.post_attention_layernorm(hidden_states, s) + hidden_states, s = self.mlp(hidden_states, s) + hidden_states = self.QAdd_mlp(residual, hidden_states, res, s) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class QMemLlamaPreTrainedModel(LlamaPreTrainedModel): + config_class = QMemLlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["QMemLlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear) or isinstance(module, QLinearTE): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class QMemLlamaModel(QMemLlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`] + + Args: + config: QMemLlamaConfig + """ + + def __init__(self, config: QMemLlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [QMemLlamaDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + _update_causal_mask = LlamaModel._update_causal_mask + forward = LlamaModel.forward + + +class QMemLlamaForCausalLM(QMemLlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = QMemLlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.forward_step_id = 0 + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + forward = LlamaForCausalLM.forward + prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation + + +class QMemLlamaForSequenceClassification(QMemLlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = QMemLlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + forward = LlamaForSequenceClassification.forward + + +AutoConfig.register("qmemllama", QMemLlamaConfig) +AutoModel.register(QMemLlamaConfig, QMemLlamaModel) +AutoModelForCausalLM.register(QMemLlamaConfig, QMemLlamaForCausalLM) diff --git a/llava/model/liger/cross_entropy.py b/llava/model/liger/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..06274f332b62ac2178904439a18c3474693242a2 --- /dev/null +++ b/llava/model/liger/cross_entropy.py @@ -0,0 +1,417 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. 
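+# Reference semantics (editorial sketch, in comments only, of what the fused Triton kernel
+# in this file computes per row, for logits x and hard label y; see the kernel for details):
+#   if softcap is set:  x = softcap * tanh(x / softcap)
+#   lse  = logsumexp(x)                      # found with an online max / running sum
+#   loss = lse - x[y]                        # plain cross entropy
+#   if label_smoothing > 0:                  # eps = label_smoothing / n_cols
+#       loss = (1 - label_smoothing) * loss + sum(-eps * x) + label_smoothing * lse
+#   loss += lse_square_scale * lse ** 2      # auxiliary z-loss
+#   if reduction == "mean":  loss /= n_non_ignore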
+ +import operator +from typing import Optional + +import torch +import triton +import triton.language as tl + +from .utils import compare_version, element_mul_kernel, is_hip + +if compare_version("triton", operator.ge, "3.0.0"): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import tanh + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import tanh +else: + from triton.language.math import tanh + +_TRUE = tl.constexpr(1) +_FALSE = tl.constexpr(0) + + +@triton.jit +def liger_cross_entropy_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + loss_ptr, + z_loss_ptr, + loss_stride, + n_cols, + n_non_ignore, + ignore_index, + lse_square_scale: tl.constexpr, + label_smoothing: tl.constexpr, + reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time + softcap, + RETURN_Z_LOSS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HAS_SOFTCAPPING: tl.constexpr, +): + """ + This kernel computes both cross entropy loss and the gradient of the input. + We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math. + + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + loss_ptr: Pointer to tensor to store the loss. + z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. + loss_stride (int): The stride of the loss tensor. + n_cols (int): The number of columns in the input tensor. + n_non_ignore (int): The number of non-ignored elements in the batch. + ignore_index (int): The index to ignore in the target. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. + RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1. + reduction (str): The string for the reduction to apply + softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). + BLOCK_SIZE (int): The block size for Triton operations. + HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. + """ + + # https://github.com/triton-lang/triton/issues/1058 + # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64 + program_id = tl.program_id(0).to(tl.int64) + + # 1. Load Y_ptr first because if the target is ignore_index, we can return right away + Y_ptr += program_id * Y_stride + y = tl.load(Y_ptr) + + # 2. locate the start index + X_ptr += program_id * X_stride + + if y == ignore_index: + # set all X_ptr as 0 + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) + return + + loss_ptr += program_id * loss_stride + z_loss_ptr += program_id * loss_stride + + # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) + # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 + + # 3. [Online softmax] first pass: find max + sum + m = float("-inf") # m is the max value. use the notation from the paper + d = 0.0 # d is the sum. 
use the notation from the paper + ori_X_y = tl.load(X_ptr + y) # we need to store the original value of X_y for the loss calculation + if HAS_SOFTCAPPING: + ori_X_y = softcap * tanh(ori_X_y / softcap) + + # Label smoothing is a general case of normal cross entropy + # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310 + scaled_x_sum = 0.0 + eps = label_smoothing / n_cols + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")) + if HAS_SOFTCAPPING: + X_block = softcap * tanh(X_block / softcap) + block_max = tl.max(X_block) + if label_smoothing > 0: + # scale X beforehand to avoid overflow + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) + m = m_new + + # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X))))) + # = log (e^(max(X)) * sum(e ^ (X_i - max(X)))) + # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d + lse = m + tl.log(d) + + # 4. [Online Softmax] Second pass: compute gradients + # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N) + # dx_y = (softmax(x_y) - 1) / N + # dx_i = softmax(x_i) / N, i != y + # For label smoothing: + # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y + # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N + # = dx_i - (1 - label_smoothing) / N + # With Z loss: + # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y + # dx_y = dx_i - (1 - label_smoothing) / N + # For 'sum' reduction, no normalization is applied: + # dx_y = softmax(x_y) - 1 + # dx_i = softmax(x_i), for i ≠ y + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")) + if HAS_SOFTCAPPING: + intermediate = tanh(X_block / softcap) + X_block = softcap * intermediate + # softmax(x_i) + X_block = tl.exp(X_block - m) / d + # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) + X_block += 2 * lse_square_scale * lse * X_block + # smoothing term + X_block += -eps + # special handle dx_y + X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) + # reduction scale + if reduction == "mean": + X_block = X_block / (n_non_ignore) + # chain rule + # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) + if HAS_SOFTCAPPING: + X_block = X_block * (1 - intermediate * intermediate) + + tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) + + # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in + # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34 + tl.debug_barrier() + + # 5. 
Calculate the loss + + # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) + # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) + # = X_y - m - log d = X_y - lse + # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1 + # So we can safely calculate log (softmax(X_y)) without overflow + loss = lse - ori_X_y + + # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps + # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) + # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) + # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: + # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd)) + # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 + # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 + # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087 + if label_smoothing > 0: + smooth_loss = scaled_x_sum + label_smoothing * lse + loss = loss * (1 - label_smoothing) + smooth_loss + + # An auxiliary loss, z_loss + # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html + z_loss = lse_square_scale * lse * lse + loss += z_loss + # Normalize the loss by the number of non-ignored elements if reduction is "mean" + if reduction == "mean": + z_loss = z_loss / n_non_ignore + loss = loss / n_non_ignore + + tl.store(loss_ptr, loss) + if RETURN_Z_LOSS == _TRUE: + tl.store(z_loss_ptr, z_loss) + + +# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 +# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning + + +_bool_to_return_z_loss = { + True: _TRUE.value, + False: _FALSE.value, +} + + +def cross_entropy_forward( + _input, + target, + ignore_index, + lse_square_scale, + label_smoothing, + reduction, + softcap, + return_z_loss, +): + if not isinstance(return_z_loss, int): + assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}" + return_z_loss = _bool_to_return_z_loss[return_z_loss] + else: + assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. 
Got: {return_z_loss}" + + BT, V = _input.shape + n_rows = BT + + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + # unreduced loss + loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) + if return_z_loss == _TRUE.value: + z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) + else: + z_loss_1d = loss_1d # dummy ptr when return_z_loss == False + + n_non_ignore = (target != ignore_index).sum().item() + + # ensure _input and target are contiguous in the last dimension + if _input.stride(-1) != 1: + _input = _input.contiguous() + if target.stride(-1) != 1: + target = target.contiguous() + + # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory + liger_cross_entropy_kernel[(n_rows,)]( + X_ptr=_input, + X_stride=_input.stride(-2), + Y_ptr=target, + Y_stride=target.stride(-1), # always 1 + loss_ptr=loss_1d, + z_loss_ptr=z_loss_1d, + loss_stride=loss_1d.stride(-1), # always 1 + n_cols=V, + n_non_ignore=n_non_ignore, + ignore_index=ignore_index, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap if softcap is not None else 0.0, + RETURN_Z_LOSS=return_z_loss, + BLOCK_SIZE=BLOCK_SIZE, + HAS_SOFTCAPPING=True if softcap is not None else False, + # TODO: 32 seems to give the best performance + # Performance is quite sensitive to num_warps + num_warps=32 if not is_hip() else 16, + ) + + loss = torch.sum(loss_1d) + if return_z_loss == _TRUE.value: + z_loss = torch.sum(z_loss_1d) + else: + z_loss = None + + return loss, z_loss, _input + + +def cross_entropy_backward(_input, grad_output): + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + pass + + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. + else: + BT, V = _input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + element_mul_kernel[(n_rows,)]( + _input, + _input.stride(-2), + grad_output, + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32 if not is_hip() else 16, + ) + + return _input + + +class LigerCrossEntropyFunction(torch.autograd.Function): + """ + This class implements a custom autograd function for the Liger Cross Entropy loss. + It overrides the forward and backward methods of the torch.autograd.Function class. + """ + + @staticmethod + def forward( + ctx, + _input: torch.Tensor, + target: torch.Tensor, + ignore_index: int = -100, + lse_square_scale: float = 0.0, + label_smoothing: float = 0.0, + reduction: str = "mean", + softcap: Optional[float] = None, + return_z_loss: bool = False, + ): + """ + The forward pass of the Liger Cross Entropy loss. + + Parameters: + ctx : The context object. + _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size. + target (tensor): The target tensor of shape (BT) where each value is in [0, V-1]. + ignore_index (int): The index to ignore in the target. + lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + reduction (str): The reduction to apply to the output: "none" | "mean | "sum". 
+            softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
+            return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+
+        Returns:
+            tuple: A tuple with the computed losses with respect to loss and z loss. The elements are tensors or None.
+        """
+        loss, z_loss, _input = cross_entropy_forward(
+            _input,
+            target,
+            ignore_index,
+            lse_square_scale,
+            label_smoothing,
+            reduction,
+            softcap,
+            return_z_loss,
+        )
+        # TODO: investigation
+        # If we don't detach the _input tensor, the memory will double
+        # Not sure why, but it seems there is a moment when both the grad and the value exist, in different locations
+        ctx.save_for_backward(_input.detach())
+        ctx.return_z_loss = return_z_loss
+
+        return loss, z_loss
+
+    @staticmethod
+    def backward(ctx, grad_output, grad_output2):
+        """
+        The backward pass of the Liger Cross Entropy loss.
+
+        Parameters:
+        ctx : The context object with saved tensors.
+        grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
+        grad_output2 (tensor): Unused; it is the gradient for the z_loss output, which is only used for logging.
+        Returns:
+        tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
+        """
+        if ctx.return_z_loss:
+            del grad_output2  # z_loss is only for logging
+
+        (_input,) = ctx.saved_tensors
+        _input = cross_entropy_backward(_input, grad_output)
+        return (
+            _input,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+def liger_fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
+    reduction = "sum" if num_items_in_batch is not None else "mean"
+    # loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
+    loss, _ = LigerCrossEntropyFunction.apply(source, target, ignore_index, 0.0, 0.0, reduction)
+    if reduction == "sum":
+        loss = loss / num_items_in_batch
+    return loss
+
+
+def LigerForCausalLMLoss(
+    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
+):
+    # Upcast to float if we need to compute the loss to avoid potential precision issues
+    logits = logits.float()
+    # Shift so that tokens < n predict n
+    shift_logits = logits[..., :-1, :].contiguous()
+    shift_labels = labels[..., 1:].contiguous()
+
+    # Flatten the tokens
+    shift_logits = shift_logits.view(-1, vocab_size)
+    shift_labels = shift_labels.view(-1)
+    # Enable model parallelism
+    shift_labels = shift_labels.to(shift_logits.device)
+    loss = liger_fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
+    return loss
diff --git a/llava/model/liger/utils.py b/llava/model/liger/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5290b98c237f43757d0ed74d91f8f3bf6a807eeb
--- /dev/null
+++ b/llava/model/liger/utils.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+"""
+This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
+See the original Unsloth repository at https://github.com/unslothai/unsloth.
+ +The following line +https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23 +is based on code from Unsloth, located at: +https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 + +Modifications made by Yanning Chen, 2024. +""" + +import functools +import importlib +import operator +from typing import Callable + +import torch +import triton +import triton.language as tl +from packaging.version import Version + + +def is_hip() -> bool: + return torch.version.hip is not None + + +def ensure_contiguous(fn): + @functools.wraps(fn) + def wrapper(ctx, *args, **kwargs): + def maybe_to_contiguous(x): + return x.contiguous() if isinstance(x, torch.Tensor) else x + + args = [maybe_to_contiguous(arg) for arg in args] + kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()} + return fn(ctx, *args, **kwargs) + + return wrapper + + +def calculate_settings(n): + # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 + + MAX_FUSED_SIZE = 65536 + BLOCK_SIZE = triton.next_power_of_2(n) + if BLOCK_SIZE > MAX_FUSED_SIZE: + raise RuntimeError( + f"Cannot launch Triton kernel since n = {n} exceeds " + f"the recommended Triton blocksize = {MAX_FUSED_SIZE}." + ) + + num_warps = 4 + if BLOCK_SIZE >= 32768: + num_warps = 32 if not is_hip() else 16 + elif BLOCK_SIZE >= 8192: + num_warps = 16 + elif BLOCK_SIZE >= 2048: + num_warps = 8 + return BLOCK_SIZE, num_warps + + +def compare_version(package: str, operator: Callable, target: str): + try: + pkg = importlib.import_module(package) + except ImportError: + return False + pkg_version = Version(pkg.__version__) + return operator(pkg_version, Version(target)) + + +def get_amp_custom_fwd_bwd() -> Callable: + if compare_version("torch", operator.ge, "2.4.0"): + return ( + functools.partial(torch.amp.custom_fwd, device_type="cuda"), + functools.partial(torch.amp.custom_bwd, device_type="cuda"), + ) + return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd + + +amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd() + + +torch_to_triton_dtype = { + torch.float32: tl.float32, + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, +} + + +@triton.jit +def element_mul_kernel( + X_ptr, + X_stride, + grad_output_ptr, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. + The multiplication is performed in-place on the tensor pointed by X_ptr. + + Parameters: + X_ptr: Pointer to the input tensor. + X_stride (int): The stride of the input tensor. + grad_output_ptr: Pointer to the gradient output value. + n_cols (int): The number of columns in the input tensor. + BLOCK_SIZE (int): The block size for Triton operations. 
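+
+    Note: within this patch, the kernel is used by `cross_entropy_backward` to rescale the
+    gradients that `liger_cross_entropy_kernel` already stored in-place in `X_ptr`, whenever
+    the incoming `grad_output` is not the scalar 1.0.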
+ """ + + # Get the program ID and convert it to int64 to avoid overflow + program_id = tl.program_id(0).to(tl.int64) + + # Locate the start index + X_ptr += program_id * X_stride + + # Load the gradient output value + grad_output = tl.load(grad_output_ptr) + + # Perform the element-wise multiplication + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) + tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..a6030ea9c5b2f29b4a26538bfe7232fdd69e0138 --- /dev/null +++ b/llava/model/llava_arch.py @@ -0,0 +1,861 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. +# LICENSE is in incl_licenses directory. + +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import json +import logging +import os +import os.path as osp +import warnings +from abc import ABC +from collections import OrderedDict, defaultdict, deque +from itertools import chain +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from einops import rearrange +from hydra.utils import instantiate +from transformers import AutoConfig, GenerationConfig, LogitsProcessor, PreTrainedModel +from transformers.modeling_utils import ContextManagers, no_init_weights + +from llava.constants import DEFAULT_SOUND_TOKEN,DEFAULT_SPEECH_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS +from llava.mm_utils import process_image, process_images, process_sounds,process_sound_masks +from llava.model.configuration_llava import LlavaConfig, ResponseFormat +from llava.model.language_model.builder import build_llm_and_tokenizer +from llava.model.multimodal_encoder.builder import build_sound_tower +from llava.model.multimodal_projector.builder import build_speech_mm_projector, build_sound_mm_projector +from llava.model.utils import get_model_config +from llava.train.sequence_parallel import get_pg_manager +from llava.utils import distributed +from llava.utils.media import extract_media +from llava.utils.tokenizer import tokenize_conversation + + +class LlavaMetaModel(ABC): + def _init_llm(self, llm_cfg, config, *args, **kwargs): + llm, tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs) + return llm, tokenizer + + def init_vlm(self, config: PreTrainedModel = None, *args, **kwargs): + # TODO(ligeng): figure out how from_config and from_pretrained works in HF implementation. 
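+        # Overview of what follows: if any sub-module already exists on `self`, we return early;
+        # otherwise we resolve the model dtype, unpack the seven sub-configs, build the LLM and
+        # tokenizer plus the sound tower and sound projector, register the "sound" encoder via
+        # Hydra's `instantiate`, and finalize everything in `post_config()`.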
+ if hasattr(self, "llm") or hasattr(self, "vision_tower") or hasattr(self, "speech_tower") or hasattr(self, "sound_tower") or hasattr(self, "mm_projector") or hasattr(self, "speech_mm_projector") or hasattr(self, "sound_mm_projector"): + # already initialized, skipped + return + + model_dtype = getattr(config, "model_dtype", "torch.float16") + if not hasattr(config, "model_dtype"): + warnings.warn("model_dtype not found in config, defaulting to torch.float16.") + config.model_dtype = model_dtype + + cfgs = get_model_config(config) + print(cfgs) + if len(cfgs) == 7: + llm_cfg, vision_tower_cfg, speech_tower_cfg,sound_tower_cfg, mm_projector_cfg, speech_mm_projector_cfg,sound_mm_projector_cfg = cfgs + else: + raise ValueError("`llm_cfg` `mm_projector_cfg` `speech_mm_projector_cfg` `sound_mm_projector_cfg` `vision_tower_cfg` `speech_tower_cfg` `sound_tower_cfg` not found in the config.") + + self.llm, self.tokenizer = self._init_llm(llm_cfg, config, *args, **kwargs) + + self.sound_tower = build_sound_tower(sound_tower_cfg, config) + self.sound_mm_projector = build_sound_mm_projector(sound_mm_projector_cfg, config) + + if isinstance(self.config, dict): + self.vocab_size = config.llm_cfg["vocab_size"] + NUM_EXTRA_TOKENS + else: + self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS + logging.info( + f"[XGrammar] config is not a dict, loading vocab size from tokenizer {self.tokenizer.vocab_size} + {NUM_EXTRA_TOKENS} => {self.vocab_size}" + ) + + # XGrammar tokenizer and grammar compiler + # lazy init only when specified json output during inference + self.grammar_compiler = None + + self.encoders = {} + for name in ["sound"]: + config = getattr(self.config, f"{name}_encoder") + if isinstance(config, str): + config = json.loads(config) + self.encoders[name] = instantiate(config, parent=self) + + self.post_config() + self.is_loaded = True + + assert ( + self.llm is not None or self.vision_tower is not None or self.speech_tower is not None or self.mm_projector is not None or self.speech_mm_projector is not None + ), "At least one of the components must be instantiated." 
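+
+    # Illustrative sketch (not part of the original code): the `sound_encoder` entry parsed in
+    # init_vlm above is expected to be a Hydra-instantiable mapping, e.g. something of the form
+    #     {"_target_": "llava.model.encoders.BasicSoundEncoder"}   # hypothetical class path
+    # so that `instantiate(cfg, parent=self)` constructs the encoder with a back-reference to this model.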
+
+    @classmethod
+    def load_from_config(cls, model_path_or_config, *args, **kwargs):
+        pass
+
+    ## FIXME: we will use this function to load the model in the future
+    @classmethod
+    def load_pretrained(cls, model_path_or_config, *args, **kwargs):
+        kwargs.pop("config", None)
+
+        if isinstance(model_path_or_config, str):
+            config = AutoConfig.from_pretrained(model_path_or_config)
+        elif isinstance(model_path_or_config, LlavaConfig):
+            config = model_path_or_config
+        else:
+            raise NotImplementedError(
+                f"wrong type, {type(model_path_or_config)} \
+                {isinstance(model_path_or_config, LlavaConfig)}"
+            )
+
+        model_dtype = getattr(config, "model_dtype", "torch.float16")
+        if not hasattr(config, "model_dtype"):
+            warnings.warn("model_dtype not found in config, defaulting to torch.float16.")
+            config.model_dtype = model_dtype
+
+        cfgs = get_model_config(config)
+        if len(cfgs) == 7:
+            llm_cfg, vision_tower_cfg, speech_tower_cfg, sound_tower_cfg, mm_projector_cfg, speech_mm_projector_cfg, sound_mm_projector_cfg = cfgs
+        else:
+            raise ValueError("`llm_cfg` `mm_projector_cfg` `speech_mm_projector_cfg` `sound_mm_projector_cfg` `vision_tower_cfg` `speech_tower_cfg` `sound_tower_cfg` not found in the config.")
+
+        init_context = [
+            no_init_weights(_enable=True),
+        ]
+
+        with ContextManagers(init_context):
+            vlm = cls(config, *args, **kwargs)
+
+        if hasattr(vlm, "llm") or hasattr(vlm, "vision_tower") or hasattr(vlm, "speech_tower") or hasattr(vlm, "sound_tower") or hasattr(vlm, "mm_projector") or hasattr(vlm, "speech_mm_projector") or hasattr(vlm, "sound_mm_projector"):
+            if vlm.is_loaded:
+                return vlm
+
+        vlm.llm, vlm.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
+        vlm.sound_tower = build_sound_tower(sound_tower_cfg, config)
+        vlm.sound_mm_projector = build_sound_mm_projector(sound_mm_projector_cfg, config)
+
+        # `load_pretrained` is a classmethod, so finalize the freshly built instance, not `self`
+        vlm.post_config()
+        vlm.is_loaded = True
+
+        # FIXME(ligeng, yunhao): llm should never be none here.
+        assert (
+            vlm.llm is not None or vlm.vision_tower is not None or vlm.speech_tower is not None or vlm.mm_projector is not None or vlm.speech_mm_projector is not None
+        ), "At least one of the components must be instantiated."
+ return vlm + + ## FIXME we will use this function to save the model in the future + def save_pretrained(self, output_dir, state_dict=None): + if state_dict is None: + # other wise fetch from deepspeed + # state_dict = accelerator.get_state_dict(is_deepspeed_enabled) + state_dict = self.state_dict() + + if getattr(self, "tokenizer", None): + self.tokenizer.save_pretrained(osp.join(output_dir, "llm")) + + if self.get_llm(): + print(f"saving llm to {osp.join(output_dir, 'llm')}") + self.llm.config._name_or_path = osp.join(output_dir, "llm") + llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k}) + self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict) + self.config.llm_cfg = self.llm.config + + + if self.get_sound_tower(): + print(f"saving sound_tower to {osp.join(output_dir, 'sound_tower')}") + self.sound_tower.config._name_or_path = osp.join(output_dir, "sound_tower") + sound_tower_state_dict = OrderedDict( + {k.split("sound_tower.sound_tower.")[-1]: v for k, v in state_dict.items() if "sound_tower" in k} + ) + self.sound_tower.sound_tower.save_pretrained( + os.path.join(output_dir, "sound_tower"), + state_dict=sound_tower_state_dict, + ) + self.config.sound_tower_cfg = self.sound_tower.config + + if self.get_sound_mm_projector(): + print(f"saving sound_mm_projector to {osp.join(output_dir, 'sound_mm_projector')}") + self.sound_mm_projector.config._name_or_path = osp.join(output_dir, "sound_mm_projector") + sound_mm_projector_state_dict = OrderedDict( + {k.split("sound_mm_projector.")[-1]: v for k, v in state_dict.items() if "sound_mm_projector" in k} + ) + self.sound_mm_projector.save_pretrained( + os.path.join(output_dir, "sound_mm_projector"), + state_dict=sound_mm_projector_state_dict, + ) + self.config.sound_mm_projector_cfg = self.sound_mm_projector.config + + ## update and save top-level config + self.config._name_or_path = output_dir + self.config.architectures = [self.__class__.__name__] + self.config.save_pretrained(output_dir) + + def get_llm(self): + llm = getattr(self, "llm", None) + if type(llm) is list: + llm = llm[0] + return llm + + def get_lm_head(self): + lm_head = getattr(self.get_llm(), "lm_head", None) + return lm_head + + def get_sound_tower(self): + sound_tower = getattr(self, "sound_tower", None) + if type(sound_tower) is list: + sound_tower = sound_tower[0] + return sound_tower + + + def get_sound_mm_projector(self): + sound_mm_projector = getattr(self, "sound_mm_projector", None) + if type(sound_mm_projector) is list: + sound_mm_projector = sound_mm_projector[0] + return sound_mm_projector + + def post_config(self): + self.training = self.get_llm().training + ## configuration + if getattr(self.config, "llm_cfg", None) is None: + self.config.llm_cfg = self.llm.config + self.config.speech_tower_cfg = self.speech_tower.config + if getattr(self.config, "sound_tower_cfg", None) is None: + self.config.sound_tower_cfg = self.sound_tower.config + if getattr(self.config, "sound_mm_projector_cfg", None) is None: + self.config.sound_mm_projector_cfg = self.sound_mm_projector.config + + def freezed_module_patch(self): + """ + Huggingface will call model.train() at each training_step. To ensure the expected behaviors for modules like dropout, batchnorm, etc., we need to call model.eval() for the freezed modules. 
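+        In this sound-only variant that means the sound tower and the sound projector are switched
+        back to eval() when their `tune_*` flags are off, while the language-model branch is
+        deliberately left as a no-op.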
+ """ + if self.training: + if self.get_llm() and not getattr(self.config, "tune_language_model", False): + pass + + if self.get_sound_tower() and not getattr(self.config, "tune_sound_tower", False): + self.get_sound_tower().eval() + if self.get_sound_mm_projector() and not getattr(self.config, "tune_sound_mm_projector", False): + self.get_sound_mm_projector().eval() + + + def encode_sound(self, sounds, masks=None): + + sound_features = self.get_sound_tower()(sounds, masks) + sound_features = self.get_sound_mm_projector()(sound_features) + return sound_features + + ## @yunhao: is there a better way to handle function call and attributes for llm? + ## support beam search + def _temporary_reorder_cache(self, past_key_values, sorted_idx): + return self.get_llm()._temporary_reorder_cache(past_key_values, sorted_idx) + + def get_input_embeddings(self): + return self.get_llm().get_input_embeddings() + + def get_output_embeddings(self): + return self.get_llm().get_output_embeddings() + + def resize_token_embeddings(self, embed_size): + self.get_llm().resize_token_embeddings(embed_size) + + +class LlavaMetaForCausalLM(ABC): + def _embed( + self, + input_ids: torch.Tensor, + media: Dict[str, List[torch.Tensor]], + media_config: Dict[str, Dict[str, Any]], + labels: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor], + media_meta: Dict[str, Dict[str, Any]]= None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + labels = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX) + attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool) + + PROCESS_GROUP_MANAGER = get_pg_manager() + if PROCESS_GROUP_MANAGER is not None: + for name in media: + self.encoders[name].end_tokens = None + + # Extract text and media embeddings + text_embeds = self.llm.model.embed_tokens(input_ids) + media_embeds = self.__embed_media_tokens(media, media_config, media_meta) + # This is a workaround to make sure the dummy embeddings are consumed + while media_embeds.get("dummy"): + dummy_embed = media_embeds["dummy"].popleft() + text_embeds += torch.sum(dummy_embed) * 0 + # Remove padding + batch_size = labels.shape[0] + + # Build inverse mapping from token ID to media name + media_tokens = {} + for name, token_id in self.tokenizer.media_token_ids.items(): + media_tokens[token_id] = name + + # -------------------------------- # + num_audio_tokens = torch.stack(media_meta["sound_embed_masks"], dim=0).sum(-1) + num_audio_tokens = torch.tensor([round(int(x) / 10) * 10 for x in num_audio_tokens]) + num_audios = len(media_embeds['sound']) # length of queue is the number of audios we have in total + max_audio_tokens, embed_dim = media_embeds['sound'][0].shape + + + audio_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to( + num_audio_tokens.device + ) < num_audio_tokens.unsqueeze(1) + + audio_embeds = [] + while media_embeds['sound']: + audio_embeds.append(media_embeds['sound'].popleft()) + audio_embeds = torch.stack(audio_embeds,dim=0) + + masked_audio_features = audio_embeds[audio_features_mask].view(-1, embed_dim) + batch_size, sequence_length = input_ids.shape + _left_padding = torch.any(attention_mask[:, 0] == 0) + _right_padding = torch.any(attention_mask[:, -1] == 0) + + left_padding = True + if batch_size > 1: + if _left_padding and not _right_padding: + left_padding = True + elif not _left_padding and _right_padding: + left_padding = False + elif not _left_padding and not _right_padding: + # both 
side is 1, so cannot tell + left_padding = self.tokenizer.padding_side == "left" + else: + # invalid attention_mask + raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}") + + # 1. Create a mask to know where special audio tokens are + special_audio_token_mask = input_ids == self.tokenizer.media_token_ids['sound'] #hard coded to just work with 'sound' + num_special_audio_tokens = torch.sum(special_audio_token_mask, dim=-1) + + # In case the Audio model or the Language model has been offloaded to CPU, we need to manually + # set the corresponding tensors into their correct target device. + target_device = text_embeds.device + attention_mask = attention_mask.to(target_device) + input_ids = input_ids.to(target_device) + num_audio_tokens = num_audio_tokens.to(target_device) + batch_indices, non_audio_indices = torch.where( + (input_ids != self.tokenizer.media_token_ids['sound']) & (attention_mask == 1) + ) + + # 2. Compute the positions where text should be written + # Calculate new positions for text tokens in merged audio-text sequence. + # `special_audio_token_mask` identifies audio tokens. Each audio token will be replaced by `audio_feat_lengths - 1` text tokens. + # `torch.cumsum` computes how each audio token shifts subsequent text token positions. + token_placeholder_num = torch.zeros_like(input_ids) + token_placeholder_num[special_audio_token_mask] = num_audio_tokens.long() - 1 + token_placeholder_num = token_placeholder_num + 1 + new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1 + max_token_num = token_placeholder_num.sum(-1).max() + nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1] + if left_padding: + new_token_positions += nb_audio_pad[:, None] # offset for left padding + text_to_overwrite = new_token_positions[batch_indices, non_audio_indices] + batch_indices, non_audio_indices, text_to_overwrite = ( + batch_indices.to(target_device), + non_audio_indices.to(target_device), + text_to_overwrite.to(target_device), + ) + + # 3. Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros( + batch_size, max_token_num, embed_dim, dtype=text_embeds.dtype, device=text_embeds.device + ) + final_attention_mask = torch.zeros( + batch_size, max_token_num, dtype=attention_mask.dtype, device=text_embeds.device + ) + final_input_ids = torch.full( + (batch_size, max_token_num), self.tokenizer.pad_token_id, dtype=input_ids.dtype, device=text_embeds.device + ) + + # 4. Fill the embeddings based on the mask. If we have ["hey" "