+
+## Overview
+
+This repo contains the PyTorch implementation of [Audio Flamingo 3: Advancing Audio Intelligence with Fully Open Large Audio-Language Models](). Audio Flamingo 3 (AF3) is a fully open, state-of-the-art Large Audio-Language Model (LALM) that advances reasoning and understanding across speech, sounds, and music. AF3 builds on previous work with innovations in:
+
+- Unified audio representation learning (speech, sound, music)
+- Flexible, on-demand chain-of-thought reasoning (Thinking in Audio)
+- Long-context audio comprehension (including speech, up to 10 minutes of audio)
+- Multi-turn, multi-audio conversational dialogue (AF3-Chat)
+- Voice-to-voice interaction (AF3-Chat)
+
+Extensive evaluations confirm AF3’s effectiveness, setting new state-of-the-art results on over 20 public audio understanding and reasoning benchmarks.
+
+
+## Main Results
+
+Audio Flamingo 3 outperforms prior SOTA models, including GAMA, Audio Flamingo, Audio Flamingo 2, Qwen-Audio, Qwen2-Audio, Qwen2.5-Omni, LTU, LTU-AS, SALMONN, AudioGPT, Gemini Flash v2, and Gemini Pro v1.5, on a number of understanding and reasoning benchmarks.
+
+## Audio Flamingo 3 Architecture
+
+Audio Flamingo 3 uses the AF-Whisper unified audio encoder, an MLP-based audio adaptor, a decoder-only LLM backbone (Qwen2.5-7B), and a streaming TTS module (AF3-Chat).
+Audio Flamingo 3 can take audio inputs up to 10 minutes long.
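+
+For orientation, the data flow can be sketched in a few lines of PyTorch. This is a minimal, illustrative sketch only: the module names, hidden sizes, and the simple concatenation of audio and text embeddings are stand-ins, not the repository's actual classes or interleaving logic.
+
+```python
+import torch
+import torch.nn as nn
+
+
+class AF3Sketch(nn.Module):
+    """Illustrative only: AF-Whisper encoder -> MLP adaptor -> decoder-only LLM."""
+
+    def __init__(self, audio_encoder: nn.Module, llm: nn.Module, d_audio: int = 1280, d_llm: int = 3584):
+        super().__init__()
+        self.audio_encoder = audio_encoder  # AF-Whisper: mel-spectrogram windows -> audio embeddings
+        self.adaptor = nn.Sequential(       # MLP adaptor projects audio embeddings into the LLM space
+            nn.Linear(d_audio, d_llm), nn.GELU(), nn.Linear(d_llm, d_llm)
+        )
+        self.llm = llm                      # decoder-only LLM backbone (Qwen2.5-7B), HF-style interface
+
+    def forward(self, mel_windows: torch.Tensor, text_embeds: torch.Tensor):
+        audio_embeds = self.adaptor(self.audio_encoder(mel_windows))
+        # The real model splices audio embeddings in at the sound-token positions; plain
+        # concatenation is used here only to keep the sketch short.
+        inputs_embeds = torch.cat([audio_embeds, text_embeds], dim=1)
+        return self.llm(inputs_embeds=inputs_embeds)
+```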
+
+## Installation
+
+```bash
+./environment_setup.sh af3
+```
+
+## Code Structure
+
+- The folder `audio_flamingo_3/` contains the main training and inference code of Audio Flamingo 3.
+- The folder `audio_flamingo_3/scripts` contains the inference scripts of Audio Flamingo 3, in case you would like to use our pretrained checkpoints on HuggingFace.
+
+Each folder is self-contained, and we expect no cross-dependencies between these folders. This repo does not contain the code for the Streaming-TTS pipeline, which will be released in the near future.
+
+## Single Line Inference
+
+To run inference with the stage 3 model directly, run the command below:
+```bash
+python llava/cli/infer_audio.py --model-base /path/to/checkpoint/af3-7b --conv-mode auto --text "Please describe the audio in detail" --media static/audio1.wav
+```
+
+To run inference with the stage 3.5 model, run the command below:
+```bash
+python llava/cli/infer_audio.py --model-base /path/to/checkpoint/af3-7b --model-path /path/to/checkpoint/af3-7b/stage35 --conv-mode auto --text "Please describe the audio in detail" --media static/audio1.wav --peft-mode
+```
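+
+If you prefer to call the model from Python rather than the CLI, the snippet below mirrors what `llava/cli/infer_audio.py` does (checkpoint paths are placeholders; the PEFT step is only needed for the stage 3.5 adapter):
+
+```python
+import llava
+from llava import conversation as clib
+from llava.media import Sound
+
+# Load the stage 3 base model and select the conversation template
+model = llava.load("/path/to/checkpoint/af3-7b")
+clib.default_conversation = clib.conv_templates["auto"].copy()
+
+# Optional: apply the stage 3.5 PEFT adapter on top of the base model
+# from peft import PeftModel
+# import torch
+# model = PeftModel.from_pretrained(model, "/path/to/checkpoint/af3-7b/stage35",
+#                                   device_map="auto", torch_dtype=torch.float16)
+
+# Multi-modal prompt: audio first, then the text instruction
+prompt = [Sound("static/audio1.wav"), "Please describe the audio in detail"]
+response = model.generate_content(prompt, response_format=None)
+print(response)
+```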
+
+## References
+
+The main training and inference code within each folder is modified from [NVILA](https://github.com/NVlabs/VILA/tree/main) ([Apache license](incl_licenses/License_1.md)).
+
+## License
+
+- The code in this repo is under [MIT license](incl_licenses/MIT_license.md).
+- The checkpoints are for non-commercial use only [NVIDIA OneWay Noncommercial License](incl_licenses/NVIDIA_OneWay_Noncommercial_License.docx). They are also subject to the [Qwen Research license](https://huggingface.co/Qwen/Qwen2.5-7B/blob/main/LICENSE), the [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and the original licenses accompanying each training dataset.
+- Notice: Audio Flamingo 3 is built with Qwen-2.5. Qwen is licensed under the Qwen RESEARCH LICENSE AGREEMENT, Copyright (c) Alibaba Cloud. All Rights Reserved.
+
+
+## Citation
+
+- Audio Flamingo 2
+```
+@article{ghosh2025audio,
+ title={Audio Flamingo 2: An Audio-Language Model with Long-Audio Understanding and Expert Reasoning Abilities},
+ author={Ghosh, Sreyan and Kong, Zhifeng and Kumar, Sonal and Sakshi, S and Kim, Jaehyeon and Ping, Wei and Valle, Rafael and Manocha, Dinesh and Catanzaro, Bryan},
+ journal={arXiv preprint arXiv:2503.03983},
+ year={2025}
+}
+```
+
+- Audio Flamingo
+```
+@inproceedings{kong2024audio,
+ title={Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities},
+ author={Kong, Zhifeng and Goel, Arushi and Badlani, Rohan and Ping, Wei and Valle, Rafael and Catanzaro, Bryan},
+ booktitle={International Conference on Machine Learning},
+ pages={25125--25148},
+ year={2024},
+ organization={PMLR}
+}
+```
+
+
diff --git a/llava/__init__.py b/llava/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c14ac15313d1fceee30d9a3ef2d69bbba5702722
--- /dev/null
+++ b/llava/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .entry import *
+from .media import *
diff --git a/llava/cli/infer_audio.py b/llava/cli/infer_audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..8041b2ab9aa52fc166c87ce4527ad5a8356857da
--- /dev/null
+++ b/llava/cli/infer_audio.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import argparse
+import importlib.util
+import json
+import os
+
+from pydantic import BaseModel
+from termcolor import colored
+
+import llava
+from llava import conversation as clib
+from llava.media import Image, Video, Sound
+from llava.model.configuration_llava import JsonSchemaResponseFormat, ResponseFormat
+from peft import PeftModel
+import torch
+
+def get_schema_from_python_path(path: str) -> str:
+ schema_path = os.path.abspath(path)
+ spec = importlib.util.spec_from_file_location("schema_module", schema_path)
+ schema_module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(schema_module)
+
+ # Get the Main class from the loaded module
+ Main = schema_module.Main
+ assert issubclass(
+ Main, BaseModel
+ ), f"The provided python file {path} does not contain a class Main that describes a JSON schema"
+ return Main.schema_json()
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-base", "-mb", type=str, required=True)
+ parser.add_argument("--model-path", "-mp", type=str, required=True)
+ parser.add_argument("--conv-mode", "-c", type=str, default="auto")
+ parser.add_argument("--text", type=str)
+ parser.add_argument("--media", type=str, nargs="+")
+ parser.add_argument("--json-mode", action="store_true")
+ parser.add_argument("--peft-mode", action="store_true")
+ parser.add_argument("--json-schema", type=str, default=None)
+ args = parser.parse_args()
+
+ # Convert json mode to response format
+ if not args.json_mode:
+ response_format = None
+ elif args.json_schema is None:
+ response_format = ResponseFormat(type="json_object")
+ else:
+ schema_str = get_schema_from_python_path(args.json_schema)
+ print(schema_str)
+ response_format = ResponseFormat(type="json_schema", json_schema=JsonSchemaResponseFormat(schema=schema_str))
+
+ # Load model
+ model = llava.load(args.model_base)
+ if args.peft_mode:
+ model = PeftModel.from_pretrained(
+ model,
+ args.model_path,
+ device_map="auto",
+ torch_dtype=torch.float16,
+ )
+ # Set conversation mode
+ clib.default_conversation = clib.conv_templates[args.conv_mode].copy()
+
+ # Prepare multi-modal prompt
+ prompt = []
+ if args.media is not None:
+ for media in args.media or []:
+ if any(media.endswith(ext) for ext in [".wav",".mp3", ".flac"]):
+ media = Sound(media)
+ else:
+ raise ValueError(f"Unsupported media type: {media}")
+ prompt.append(media)
+ if args.text is not None:
+ prompt.append(args.text)
+
+ # Generate response
+ response = model.generate_content(prompt, response_format=response_format)
+ print(colored(response, "cyan", attrs=["bold"]))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llava/constants.py b/llava/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4e557b88bf7437b49bb88c754d71218ed79dfae
--- /dev/null
+++ b/llava/constants.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+CONTROLLER_HEART_BEAT_EXPIRATION = 30
+WORKER_HEART_BEAT_INTERVAL = 15
+
+LOGDIR = "."
+
+# Model Constants
+IGNORE_INDEX = -100
+DEFAULT_SOUND_TOKEN = "<sound>"
+DEFAULT_SPEECH_TOKEN = "<speech>"
+SENTINEL_TOKEN = ""
+
+MEDIA_TOKENS = {
+    "speech": "<speech>",
+    "sound": "<sound>",
+}
+
+
+"""
+151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+151648: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+151649: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+151650: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+151651: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+151652: AddedToken("", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
+
+"""
+NUM_EXTRA_TOKENS = 10
diff --git a/llava/conversation.py b/llava/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b42e5bf88571231b262040d3601e63006e125e9
--- /dev/null
+++ b/llava/conversation.py
@@ -0,0 +1,197 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+import dataclasses
+from enum import Enum, auto
+from typing import List
+
+from llava.utils.logging import logger
+
+
+class SeparatorStyle(Enum):
+ """Different separator style."""
+
+ AUTO = auto()
+ TWO = auto()
+ MPT = auto()
+ PLAIN = auto()
+ LLAMA_3 = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ sep_style: SeparatorStyle = SeparatorStyle.AUTO
+ sep: str = "###"
+ sep2: str = None
+ version: str = "Unknown"
+
+ def get_prompt(self):
+ messages = self.messages
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
+ messages = self.messages.copy()
+ init_role, init_msg = messages[0].copy()
+            init_msg = init_msg[0].replace("<image>", "").strip()
+ messages[0] = (init_role, "\n" + init_msg)
+
+ if self.sep_style == SeparatorStyle.TWO:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[0]
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
+ ret = self.system + self.sep
+ for rid, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message = message[0]
+ sep = self.sep if rid < len(messages) - 1 else self.sep2
+ ret += role + message + sep
+ else:
+ ret += role
+ elif self.sep_style == SeparatorStyle.MPT:
+ ret = self.system + self.sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + message + self.sep
+ else:
+ ret += role
+ elif self.sep_style == SeparatorStyle.PLAIN:
+ seps = [self.sep, self.sep2]
+ ret = self.system
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += message + seps[i % 2]
+ else:
+ ret += ""
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ return ret
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2,
+ version=self.version,
+ )
+
+
+conv_auto = Conversation(
+ system="",
+ roles=("", ""),
+ messages=(),
+ sep_style=SeparatorStyle.AUTO,
+ sep="\n",
+)
+
+conv_vicuna_v1 = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+    sep2="</s>",
+)
+
+conv_llava_plain = Conversation(
+ system="",
+ roles=("", ""),
+ messages=(),
+ sep_style=SeparatorStyle.PLAIN,
+ sep="\n",
+)
+
+hermes_2 = Conversation(
+ system="<|im_start|>system\nAnswer the questions.",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+ messages=(),
+ version="hermes-2",
+)
+
+# Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template.
+llama_3_chat = Conversation(
+ system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
+ "You are able to understand the visual content that the user provides, "
+ "and assist the user with a variety of tasks using natural language.",
+ roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
+ version="llama_v3",
+ messages=(),
+ sep_style=SeparatorStyle.LLAMA_3,
+ sep="<|eot_id|>",
+ sep2="<|end_of_text|>",
+)
+
+
+default_conversation = conv_auto
+conv_templates = {
+ "auto": conv_auto,
+ "hermes-2": hermes_2,
+ "llama_3": llama_3_chat,
+ "v1": conv_vicuna_v1,
+ "vicuna_v1": conv_vicuna_v1,
+ "plain": conv_llava_plain,
+}
+
+
+CONVERSATION_MODE_MAPPING = {
+ "vila1.5-3b": "vicuna_v1",
+ "vila1.5-8b": "llama_3",
+ "vila1.5-13b": "vicuna_v1",
+ "vila1.5-40b": "hermes-2",
+ "llama-3": "llama_3",
+ "llama3": "llama_3",
+}
+
+
+def auto_set_conversation_mode(model_name_or_path: str) -> str:
+ global default_conversation
+ for k, v in CONVERSATION_MODE_MAPPING.items():
+ if k in model_name_or_path.lower():
+ logger.info(f"Setting conversation mode to `{v}` based on model name/path `{model_name_or_path}`.")
+ default_conversation = conv_templates[v]
+ return
diff --git a/llava/data/__init__.py b/llava/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..22d60dab57c08f90e224f3d95047c5dfbf14de6d
--- /dev/null
+++ b/llava/data/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .builder import *
+from .dataset import *
+from .datasets_mixture import *
diff --git a/llava/data/base.py b/llava/data/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..925a500cf78667b3ab862a72ab1b564e86c2c746
--- /dev/null
+++ b/llava/data/base.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import random
+from typing import Any, Dict, List
+
+import torch
+from torch.utils.data import Dataset
+from transformers import PreTrainedTokenizer
+
+from llava.mm_utils import dynamic_process_images_and_prompt, dynamic_s2_process_images_and_prompt, process_images
+from llava.train.args import DataArguments
+from llava.utils.logging import logger
+from llava.utils.media import extract_media
+from llava.utils.tokenizer import preprocess_conversation
+
+__all__ = ["BaseDataset"]
+
+def _process_speech(speech: List[Any], data_args: DataArguments) -> torch.Tensor:
+ return torch.tensor(speech)
+
+def _process_sound(sound: List[Any], data_args: DataArguments) -> torch.Tensor:
+ return torch.tensor(sound)
+
+def _process_sound_masks(sound_masks: List[Any], data_args: DataArguments) -> torch.Tensor:
+ return torch.tensor(sound_masks)
+
+
+class BaseDataset(Dataset):
+ def __init__(
+ self,
+ tokenizer: PreTrainedTokenizer,
+ data_args: DataArguments,
+ no_system_prompt: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__()
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+ self.no_system_prompt = no_system_prompt
+ self.instances = []
+ self.enable_dynamic_res = False
+ self.enable_dynamic_res_s2 = False
+ # global_batch_size: int,
+ self.global_batch_size = kwargs.get("global_batch_size", 1)
+
+ # by default, dataset cls will resample on failure
+ self.resample_on_failure = kwargs.get("resample_on_failure", True)
+
+ def process(self, instance: Dict[str, Any]) -> List[Dict[str, Any]]:
+ raise NotImplementedError
+
+ def __getitem__(self, index: int) -> Dict[str, Any]:
+ instance = self.instances[index]
+
+ try:
+ # Process instance to conversation
+ conversation = self.process(instance)
+
+ # Extract media from conversation
+ media, media_meta = extract_media(conversation, self.data_args)
+
+ if "speech" in media:
+ processed_speech = _process_speech(media["speech"], self.data_args)
+ if "sound" in media:
+ processed_sound = _process_sound(media["sound"], self.data_args)
+ processed_sound_feature_masks = _process_sound_masks(media_meta["sound_feature_masks"], self.data_args)
+ processed_sound_embed_masks = _process_sound_masks(media_meta["sound_embed_masks"], self.data_args)
+ # Prepare "input_ids" and "labels" for training
+ data = preprocess_conversation(conversation, self.tokenizer, no_system_prompt=self.no_system_prompt)
+
+ if "speech" in media:
+ data["speech"] = processed_speech
+ if "sound" in media:
+ data["sound"] = processed_sound
+ data["sound_feature_masks"] = processed_sound_feature_masks
+ data["sound_embed_masks"] = processed_sound_embed_masks
+
+ except Exception as e:
+ if not self.resample_on_failure:
+ raise e
+ else:
+ logger.exception(f"Error processing instance '{instance}': '{e}'. Resampling.")
+ return self.__getitem__(random.randint(0, len(self.instances) - 1))
+
+ return data
+
+ def __len__(self) -> int:
+ return len(self.instances)
diff --git a/llava/data/builder.py b/llava/data/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..871d380c089e5948a58a668846ecf96faaee4afa
--- /dev/null
+++ b/llava/data/builder.py
@@ -0,0 +1,193 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import os
+import os.path as osp
+from itertools import chain
+from typing import Any, List, Optional
+
+import torch
+import torch.distributed as dist
+from hydra.utils import instantiate
+from torch.utils.data import ConcatDataset, Dataset
+from transformers import PreTrainedTokenizer
+
+from llava.data.datasets_mixture import DATASETS_LEGACY
+from llava.train.args import DataArguments, TrainingArguments
+from llava.utils import io
+from llava.utils.logging import logger
+import time
+import numpy as np
+__all__ = ["DATASETS", "MIXTURES", "register_datasets", "register_mixtures", "parse_mixture", "build_dataset"]
+
+
+def load_dataset_yaml(name):
+ fname = f"{name}.yaml" if not name.endswith(".yaml") else name
+
+ # yaml under llava/data/registry/datasets
+ repo_path = osp.join(osp.dirname(__file__), "registry", "datasets", fname)
+ if osp.exists(repo_path):
+ return repo_path
+
+    # yaml given as an absolute or user-expanded path
+ abs_path = osp.expanduser(fname)
+ if osp.exists(abs_path):
+ return abs_path
+
+ raise FileNotFoundError(f"Dataset '{name}' is not found in the {repo_path} or {abs_path}.")
+
+
+def register_datasets(name: Optional[str] = None):
+ if name is None:
+ name = os.environ.get("VILA_DATASETS", "default")
+ logger.info(f"Registering datasets from environment: '{name}'.")
+ # return io.load(osp.join(osp.dirname(__file__), "registry", "datasets", f"{name}.yaml"))
+ dataset_meta = {}
+ for _name in name.split(","):
+ yamlpath = load_dataset_yaml(_name)
+ logger.info(f"Registering datasets from: '{yamlpath}'.")
+ meta = io.load(yamlpath)
+ dataset_meta.update(meta)
+ return dataset_meta
+
+
+def register_mixtures():
+ return io.load(os.path.join(os.path.dirname(__file__), "registry", "mixtures.yaml"))
+
+
+DATASETS = register_datasets()
+MIXTURES = register_mixtures()
+
+
+def parse_mixture(mixture: str) -> List[str]:
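+    # A mixture is a "+"-separated list of names; entries that are themselves mixtures are expanded recursively.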
+ names = mixture.split("+") if "+" in mixture else [mixture]
+ while any(name in MIXTURES for name in names):
+ names = list(chain(*[MIXTURES.get(name, [name]) for name in names]))
+ return sorted(names)
+
+
+class SubsetDataset(Dataset):
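+    # Exposes only the first `limit` fraction of the wrapped dataset (limit is a float in (0, 1]).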
+ def __init__(self, dataset: Dataset, limit: int) -> None:
+ super().__init__()
+ self.dataset = dataset
+ self.limit = limit
+
+ def __len__(self) -> int:
+ return int(len(self.dataset) * self.limit)
+
+ def __getitem__(self, index: int) -> Any:
+ return self.dataset[index % len(self.dataset)]
+
+class RepeatedDataset(Dataset):
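+    # Repeats the wrapped dataset `times` times so it can be up-weighted within a mixture.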
+ def __init__(self, dataset: Dataset, times: int) -> None:
+ super().__init__()
+ self.dataset = dataset
+ self.times = times
+
+ def __len__(self) -> int:
+ return len(self.dataset) * self.times
+
+ def __getitem__(self, index: int) -> Any:
+ return self.dataset[index % len(self.dataset)]
+
+
+def get_world_size():
+ if torch.distributed.is_initialized():
+ return torch.distributed.get_world_size()
+ else:
+ return 1
+
+
+def build_dataset(
+ mixture: str,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+ tokenizer: PreTrainedTokenizer,
+) -> Dataset:
+ logger.warning(f"Training VILA with mixture '{mixture}'.")
+ datasets = []
+ dataset_rng = np.random.default_rng(1234)
+ for name in parse_mixture(mixture):
+
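+        # Mixture entry syntax: "name*N" repeats the dataset N times; "name#P" keeps only P percent of it.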
+ if "*" in name:
+ name, times = name.split("*")
+ times = int(times)
+ else:
+ times = 1
+ limit_dataset = False
+ if "#" in name:
+ # we limit the max length of this dataset
+ name, max_length_percent = name.split("#")
+ limit_dataset = True
+ if DATASETS is not None and name in DATASETS:
+ if name in DATASETS_LEGACY:
+ logger.warning(f"Dataset '{name}' exists in both new and legacy registries. Using the new one.")
+ dataset = instantiate(DATASETS[name], _partial_=True)(
+ tokenizer=tokenizer,
+ data_args=data_args,
+ global_batch_size=(
+ training_args.per_device_train_batch_size
+ # * torch.distributed.get_world_size()
+ * get_world_size()
+ * training_args.gradient_accumulation_steps
+ ),
+ )
+ elif name in DATASETS_LEGACY:
+ logger.warning(f"Dataset '{name}' is from the legacy registry. Please consider migrating it.")
+ dataset = build_dataset_legacy(
+ name,
+ data_args=data_args,
+ training_args=training_args,
+ tokenizer=tokenizer,
+ )
+ else:
+ raise ValueError(f"Dataset '{name}' is not found in the registries.")
+
+
+ if limit_dataset:
+ # we limit the max length of this dataset
+ max_length = int(float(int(max_length_percent) / 100.) * len(dataset))
+ dataset = SubsetDataset(dataset, float(int(max_length_percent) / 100.))
+
+ if times > 1:
+ dataset = RepeatedDataset(dataset, times)
+ datasets.append(dataset)
+ return ConcatDataset(datasets)
+
+
+def build_dataset_legacy(
+ name: str,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+ tokenizer: PreTrainedTokenizer,
+) -> Dataset:
+ from llava.data.dataset import (
+ LazySupervisedDataset,
+ LazyWDSDataset,
+ )
+
+ dataset = DATASETS_LEGACY[name]
+ dataset_type = dataset.dataset_type
+ if dataset_type == "torch":
+ dataset_cls = LazySupervisedDataset
+ elif dataset_type == "wds":
+ dataset_cls = LazyWDSDataset
+ else:
+ raise NotImplementedError(f"{dataset_type} is not supported.")
+
+ data_args.meta_path = getattr(dataset, "meta_path", None)
+ data_args.caption_choice = getattr(dataset, "caption_choice", None)
+ data_args.caption_choice_2 = getattr(dataset, "caption_choice_2", None)
+ data_args.start_idx = getattr(dataset, "start_idx", None)
+ data_args.end_idx = getattr(dataset, "end_idx", None)
+
+ return dataset_cls(
+ tokenizer=tokenizer,
+ data_path=dataset.data_path,
+ image_folder=getattr(dataset, "image_path"),
+ data_args=data_args,
+ training_args=training_args,
+ )
diff --git a/llava/data/collate.py b/llava/data/collate.py
new file mode 100644
index 0000000000000000000000000000000000000000..4795b77a246acce9294f07b189f0a06725b71f2c
--- /dev/null
+++ b/llava/data/collate.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from dataclasses import dataclass
+from typing import Any, Dict, Sequence
+
+import torch
+from transformers import PreTrainedTokenizer
+
+from llava.constants import IGNORE_INDEX
+from llava.utils.logging import logger
+
+__all__ = ["DataCollator"]
+
+
+@dataclass
+class DataCollator:
+ tokenizer: PreTrainedTokenizer
+
+ def __init__(self, tokenizer: PreTrainedTokenizer):
+ super().__init__()
+ self.tokenizer = tokenizer
+
+ def __call__(self, instances: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
+ # Gather everything from the batch
+ input_ids, labels, media, block_sizes = [], [], {name: [] for name in self.tokenizer.media_tokens}, []
+
+ media_meta = {}
+
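+        # media_meta carries per-window Whisper feature masks and encoder-output embedding masks for sound
+        # inputs (plus frame times for visual inputs), kept aligned with the media objects gathered below.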
+ media_meta["sound_feature_masks"] = []
+ media_meta["sound_embed_masks"] = []
+ media_meta["frame_times"] = []
+ for instance in instances:
+ if isinstance(instance["input_ids"], torch.Tensor):
+ input_ids.append(instance["input_ids"])
+ labels.append(instance["labels"])
+ for name in media:
+ objs = instance.get(name)
+ objs = objs if objs is not None else []
+ media[name].append([obj for obj in objs])
+ if instance.get("sound") is not None:
+ for name_k in media_meta:
+ if "sound" in name_k:
+ objs = instance.get(name_k)
+ media_meta[name_k].append([obj for obj in objs])
+ if instance.get("video") is not None or instance.get("image") is not None:
+ for name_k in media_meta:
+ if "frame" in name_k:
+ objs = instance.get(name_k)
+ media_meta[name_k].append([obj for obj in objs])
+ if "block_sizes" in instance:
+ block_sizes.append(instance["block_sizes"])
+ else:
+ block_sizes.append(
+ [None for _ in range(len(instance.get("image")))] if instance.get("image") is not None else []
+ )
+ else:
+ input_ids.extend(instance["input_ids"])
+ labels.extend(instance["labels"])
+ for name in media:
+ objs = instance.get(name)
+ objs = objs if objs is not None else [[] for _ in range(len(instance["input_ids"]))]
+ media[name].extend(objs)
+ if instance.get("sound") is not None:
+ for name_k in media_meta:
+ if "sound" in name_k:
+ objs = instance.get(name_k)
+ media_meta[name_k].extend(objs)
+ if instance.get("video") is not None or instance.get("image") is not None:
+ for name_k in media_meta:
+ if "frame" in name_k:
+ objs = instance.get(name_k)
+ media_meta[name_k].append([obj for obj in objs])
+ if "block_sizes" in instance:
+ block_sizes.extend(instance["block_sizes"])
+ else:
+ block_sizes.extend(
+ [[None for _ in range(len(objs))] for objs in instance.get("image")]
+ if instance.get("image") is not None
+ else [[] for _ in range(len(instance["input_ids"]))]
+ )
+
+ batch_size = len(input_ids)
+
+
+ # Check if the number of media objects (or the number of block sizes) matches the number of media tokens
+ for name in media:
+ for k in range(batch_size):
+ if name == "image" and not all([_ is None for _ in block_sizes[k]]):
+ actual = len(block_sizes[k])
+ else:
+ actual = len(media[name][k])
+ expected = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
+ if actual != expected:
+ raise ValueError(
+ f"Number mismatch between {name} objects and {name} tokens. "
+ f"There are {expected} {name} tokens but {actual} {name} objects."
+ )
+
+ # Batchify the inputs
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id,
+ )
+ labels = torch.nn.utils.rnn.pad_sequence(
+ labels,
+ batch_first=True,
+ padding_value=IGNORE_INDEX,
+ )
+ input_ids = input_ids[:, : self.tokenizer.model_max_length]
+ labels = labels[:, : self.tokenizer.model_max_length]
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
+
+ # Truncate media objects if necessary
+ for name in media:
+ objects = []
+ for k in range(batch_size):
+ if name == "image" and not all([_ is None for _ in block_sizes[k]]):
+ actual = len(media[name][k])
+ num_large_scale_blocks = sum([x * y for x, y in block_sizes[k]])
+ num_small_scale_blocks = actual - num_large_scale_blocks
+ num_small_scale_blocks_each_img = num_small_scale_blocks // len(block_sizes[k])
+ expected_full_image = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
+ expected = (
+ sum([x * y for x, y in block_sizes[k][:expected_full_image]])
+ + num_small_scale_blocks_each_img * expected_full_image
+ )
+ if actual > expected:
+ logger.warning(f"Truncating the number of {name} objects from {actual} to {expected}")
+ media[name][k] = media[name][k][:expected]
+ objects.extend(media[name][k])
+ block_sizes[k] = block_sizes[k][:expected_full_image]
+ else:
+ actual = len(media[name][k])
+ expected = (input_ids[k] == self.tokenizer.media_token_ids[name]).sum().item()
+ if actual > expected:
+ logger.warning(f"Truncating the number of {name} objects from {actual} to {expected}")
+ media[name][k] = media[name][k][:expected]
+ objects.extend(media[name][k])
+ if name == "image":
+ block_sizes[k] = block_sizes[k][:expected]
+ media[name] = objects
+
+ for name in media_meta:
+ objects = []
+ for k in range(batch_size):
+ try:
+ objects.extend(media_meta[name][k])
+ except:
+ continue
+ media_meta[name] = objects
+
+ # Flatten block sizes from [[bls_im1_instance1, bls_im2_instance1], [bls_im1_instance2, bls_im2_instance2], ...] to [bls_im1_instance1, bls_im2_instance1, bls_im1_instance2, bls_im2_instance2, ...]
+ block_sizes = sum(block_sizes, [])
+ return {
+ "input_ids": input_ids,
+ "media": media,
+ "media_config": {"image": {"block_sizes": block_sizes}, "video": {}, "speech": {}, "sound": {}},
+ "labels": labels,
+ "attention_mask": attention_mask,
+ "media_meta": media_meta,
+ }
diff --git a/llava/data/dataset.py b/llava/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..29c8b4c12ceca36f709b4ac87ab18fefae57df71
--- /dev/null
+++ b/llava/data/dataset.py
@@ -0,0 +1,1635 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import copy
+import io
+import json
+import os
+import os.path as osp
+import random
+import time
+import warnings
+from dataclasses import dataclass
+from typing import Dict, Sequence
+import math
+import numpy as np
+import PIL
+import torch
+import transformers
+from PIL import Image, ImageFile
+from torch.utils.data import Dataset, default_collate
+from transformers import PreTrainedTokenizer
+from transformers import AutoFeatureExtractor
+import kaldiio
+import llava.data.datasets_mixture as datasets_mixture
+from llava import conversation as conversation_lib
+from llava.constants import DEFAULT_SOUND_TOKEN,DEFAULT_SPEECH_TOKEN, IGNORE_INDEX
+from llava.data.collate import DataCollator
+from llava.mm_utils import (
+ load_audio,
+ get_num_windows,
+ tokenizer_image_token,
+)
+from torchvision import transforms
+from llava.train.args import DataArguments, TrainingArguments
+from llava.train.sequence_parallel import (
+ extract_local_from_list,
+ extract_local_input_ids,
+ extract_local_position_ids,
+ get_pg_manager,
+)
+from llava.utils.tokenizer import preprocess_conversation
+# import torchaudio
+from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, UniformClipSampler
+import soundfile as sf
+from librosa import resample as librosa_resample
+import whisper
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+PIL.Image.MAX_IMAGE_PIXELS = 1000000000
+
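+# The two helpers below convert between float32 waveforms in [-1, 1] and int16 PCM; the round trip
+# used at load time quantizes the audio to 16-bit precision before feature extraction.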
+def int16_to_float32(x):
+ return (x / 32767.0).astype(np.float32)
+
+
+def float32_to_int16(x):
+ x = np.clip(x, a_min=-1., a_max=1.)
+ return (x * 32767.).astype(np.int16)
+
+
+
+def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
+ is_multimodal = data_args.is_multimodal
+ if not is_multimodal:
+ return sources
+
+ for source in sources:
+ concat_values = "".join([sentence["value"] for sentence in source])
+ for sid, sentence in enumerate(source):
+            # Normalize spacing so that each sound/speech token is followed by exactly one newline.
+
+ if DEFAULT_SOUND_TOKEN in sentence["value"]:
+ sentence["value"] = sentence["value"].replace(DEFAULT_SOUND_TOKEN, f"{DEFAULT_SOUND_TOKEN}\n")
+ sentence["value"] = sentence["value"].replace(f"{DEFAULT_SOUND_TOKEN}\n\n", f"{DEFAULT_SOUND_TOKEN}\n")
+ if DEFAULT_SPEECH_TOKEN in sentence["value"]:
+ sentence["value"] = sentence["value"].replace(DEFAULT_SPEECH_TOKEN, f"{DEFAULT_SPEECH_TOKEN}\n")
+ sentence["value"] = sentence["value"].replace(f"{DEFAULT_SPEECH_TOKEN}\n\n", f"{DEFAULT_SPEECH_TOKEN}\n")
+ return sources
+
+
+def preprocess_plain(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ assert len(source) == 2
+ assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
+ source[0]["value"] = DEFAULT_IMAGE_TOKEN
+ conversation = source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep
+ conversations.append(conversation)
+ # tokenize conversations
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations]
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
+ target[:tokenized_len] = IGNORE_INDEX
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+def preprocess(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False,
+ no_system_prompt: bool = False,
+) -> Dict:
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
+ return preprocess_plain(sources, tokenizer)
+ return default_collate(
+ [
+ preprocess_conversation(conversation, tokenizer, no_system_prompt=no_system_prompt)
+ for conversation in sources
+ ]
+ )
+
+
+class LazySupervisedDataset(Dataset):
+ """Dataset for supervised fine-tuning.
+ This class is originally implemented by the LLaVA team and modified by
+ Ji Lin and Haotian Tang.
+ """
+
+ def __init__(
+ self,
+ data_path: str,
+ image_folder: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+ ):
+ super().__init__()
+ try:
+ with open(data_path) as fp:
+ list_data_dict = json.load(fp)
+ except:
+ with open(data_path) as fp:
+ list_data_dict = [json.loads(q) for q in fp]
+
+ # rank0_print("Formatting inputs...Skip in lazy mode")
+ print("Formatting inputs...Skip in lazy mode")
+ self.tokenizer = tokenizer
+ self.list_data_dict = list_data_dict
+ self.data_args = data_args
+ self.image_folder = image_folder
+ self.wav_processor = AutoFeatureExtractor.from_pretrained('/lustre/fsw/portfolios/adlr/users/sreyang/flamingo_v2/NV-Whisper')
+
+ def __len__(self):
+ return len(self.list_data_dict)
+
+ @property
+ def lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ img_tokens = 128 if "image" in sample else 0
+ length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
+ return length_list
+
+ @property
+ def modality_lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ if 'duration' in sample.keys():
+ duration = sample["duration"]
+ else:
+ duration = 10.
+ try:
+ cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"]) + int(math.ceil(duration * 25))
+ cur_len = cur_len if "sound" in sample else -cur_len
+ length_list.append(cur_len)
+ except:
+ try:
+ cur_len = 0 + int(math.ceil(duration * 25))
+ cur_len = cur_len if "sound" in sample else -cur_len
+ length_list.append(cur_len)
+ except:
+ cur_len = 0 + int(math.ceil(10. * 25))
+ cur_len = cur_len if "sound" in sample else -cur_len
+ length_list.append(cur_len)
+ return length_list
+
+ @staticmethod
+ def _load_sound(sound_file, wav_processor, sample_rate=16000, window_length=30.0, window_overlap=0.0, max_num_window=3, audio_start = 0.0):
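+        # Split the waveform into at most `max_num_window` windows of `window_length` seconds, run the
+        # Whisper feature extractor on each window, and return the stacked mel features together with
+        # per-window feature masks (mel-frame level) and embedding masks (encoder-output level).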
+ if sound_file is None:
+ return None
+ window_length = int(window_length * sample_rate)
+ window_overlap = int(window_overlap * sample_rate)
+ max_num_window = int(max_num_window)
+ duration = max_num_window * (window_length - window_overlap) + window_overlap
+
+ sound_outputs = []
+ audio_feature_masks = []
+ audio_embed_masks = []
+
+ try:
+ sound_filename = str.split(sound_file, '/')[-1]
+ if '.ark' in sound_filename:
+ sound = kaldiio.load_mat(sound_file)
+ audio_data = sound[1]
+ audio_data=audio_data.astype(np.float16)
+ else:
+ audio_data = load_audio(sound_file, sample_rate, duration, audio_start) # already cuts to max duration
+ T = len(audio_data)
+ audio_data = audio_data.reshape(1, -1)
+ num_windows, full_length = get_num_windows(T, sample_rate, max_num_window)
+
+ audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float()
+
+ for i in range(num_windows):
+ audio_embed_mask = torch.zeros(750)
+ start = i * (window_length - window_overlap)
+ audio_data_tensor_this = audio_data_tensor[:, start:start+window_length]
+ orig_length = audio_data_tensor_this.shape[1]
+ audio_data_tensor_this = wav_processor(audio_data_tensor_this.cpu().numpy(), sampling_rate=sample_rate, return_tensors="pt") #.squeeze(0) text="dummy", audios=audio_data_tensor_this, return_tensors="pt") #
+ sound_outputs.append(audio_data_tensor_this["input_features"])
+ # calculate the mask for the input melspec to Whisper
+ melspec_frames_this_window = int(math.ceil(orig_length / 160))
+ feature_attention_mask = torch.zeros(3000, dtype=torch.int32)
+ feature_attention_mask[:melspec_frames_this_window] = 1
+ audio_feature_masks.append(feature_attention_mask.unsqueeze(0))
+ # calculate the mask for the output embedding for use in AF2
+ conv_lengths = (melspec_frames_this_window - 1) // 2 + 1
+ output_embedding_lengths = (conv_lengths - 2) // 2 + 1
+ audio_embed_mask[:output_embedding_lengths] = 1
+ audio_embed_masks.append(audio_embed_mask)
+ except:
+ print('error loading file', sound_file)
+ sound_outputs.append(torch.zeros(1,128,3000))
+ audio_feature_masks.append(torch.zeros(1,3000, dtype=torch.int32))
+ audio_embed_masks.append(torch.zeros(750))
+
+ return torch.stack(sound_outputs, dim=0), torch.stack(audio_feature_masks, dim=0), torch.stack(audio_embed_masks, dim=0)
+
+ @staticmethod
+ def _load_speech(speech_path,sample_rate=16000):
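+        # Load the speech file, pad or trim it to 30 seconds, and return its Whisper log-mel spectrogram.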
+ if speech_path is None:
+ return None
+
+ speech_outputs = []
+ try:
+ speech = whisper.load_audio(speech_path)
+ speech = whisper.pad_or_trim(speech)
+ mel = whisper.log_mel_spectrogram(speech)
+ speech_outputs.append(mel.unsqueeze(0))
+ except:
+ speech_outputs.append(torch.zeros(1,80,3000))
+ return torch.stack(speech_outputs, dim=0)
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ sources = self.list_data_dict[i]
+ if isinstance(i, int):
+ sources = [sources]
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
+
+ import re
+ if "sound" in self.list_data_dict[i]:
+ # chat data loading
+ if isinstance(self.list_data_dict[i]["sound"],list):
+ sound_files = self.list_data_dict[i]["sound"]
+ conversations_raw = self.list_data_dict[i]["conversations"]
+
+                # Step 1: Extract <sound-N> tags in order of appearance
+                sound_tag_pattern = re.compile(r"<sound-(\d+)>")
+                ordered_sound_tags = []
+
+                for turn in conversations_raw:
+                    tags = sound_tag_pattern.findall(turn["value"])
+                    ordered_sound_tags.extend([f"<sound-{tag}>" for tag in tags])
+
+                # Step 2: Load sound tensors in the order of tags
+                sound_tensor = []
+                audio_feature_masks = []
+                audio_embed_masks = []
+                sound_token_map = {}
+
+                for tag in ordered_sound_tags:
+                    idx = int(tag.split('-')[1][:-1])
+                    if tag not in sound_token_map:
+                        this_sound_tensor, af_mask, ae_mask = self._load_sound(sound_files[idx], self.wav_processor, max_num_window=self.data_args.audio_frames)
+                        this_sound_tensor = this_sound_tensor.squeeze(1)  # (windows x 750 x 2048)
+                        sound_token_map[tag] = (f"{DEFAULT_SOUND_TOKEN}\n" * this_sound_tensor.shape[0]).rstrip()
+                        sound_tensor.append(this_sound_tensor)
+                        audio_feature_masks.append(af_mask)
+                        audio_embed_masks.append(ae_mask)
+                    else:
+                        # If already loaded, still append to match sequence
+                        this_sound_tensor, af_mask, ae_mask = self._load_sound(sound_files[idx], self.wav_processor, max_num_window=self.data_args.audio_frames)
+                        this_sound_tensor = this_sound_tensor.squeeze(1)
+                        sound_tensor.append(this_sound_tensor)
+                        audio_feature_masks.append(af_mask)
+                        audio_embed_masks.append(ae_mask)
+
+
+                # Process conversations and inject sound markers
+                conversation = []
+                for turn in conversations_raw:
+                    role = turn["from"]
+                    value = turn["value"]
+
+                    # Replace each <sound-N> tag with the corresponding repeated sound tokens
+                    for tag, sound_token in sound_token_map.items():
+                        value = value.replace(tag, sound_token)
+
+                    conversation.append({
+                        "from": role,
+                        "value": value.rstrip()
+                    })
+
+                sources = [conversation]
+                sound_tensor = torch.cat(sound_tensor, dim=0)
+                audio_feature_masks = torch.cat(audio_feature_masks, dim=0)
+                audio_embed_masks = torch.cat(audio_embed_masks, dim=0)
+            else:
+                sound_file = self.list_data_dict[i]["sound"]
+                question = str(self.list_data_dict[i]["conversations"][0]["value"].rstrip())
+                answer = str(self.list_data_dict[i]["conversations"][1]["value"]).rstrip()
+                # Strip any pre-existing sound/speech tokens from the question; sound tokens are re-inserted below
+                for tok in (DEFAULT_SOUND_TOKEN, DEFAULT_SPEECH_TOKEN):
+                    question = question.replace(f"{tok}\n", "").replace(f"\n{tok}", "").replace(tok, "")
+                sound_tensor, audio_feature_masks, audio_embed_masks = self._load_sound(sound_file, self.wav_processor, max_num_window=self.data_args.audio_frames)
+                sound_tensor = sound_tensor.squeeze(1)  # squeeze the extra dimension added by processing a batch of 1 --> (windows x 750 x 2048)
+                question = f"{DEFAULT_SOUND_TOKEN}\n" * sound_tensor.shape[0] + question
+                conversation = [
+                    {"from": "human", "value": question},
+                    {"from": "gpt", "value": answer},
+                ]
+
+                sources = [conversation]
+ data_dict = preprocess(
+ sources,
+ self.tokenizer,
+ has_image=(
+ "sound" in self.list_data_dict[i]
+ ),
+ )
+ if isinstance(i, int):
+ data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
+
+ if "sound" in self.list_data_dict[i]:
+ data_dict["sound"] = sound_tensor
+ data_dict["sound_feature_masks"] = audio_feature_masks
+ data_dict["sound_embed_masks"] = audio_embed_masks
+ if "speech" in self.list_data_dict[i]:
+ data_dict["speech"] = speech_tensor
+
+ return data_dict
+
+
+class LazyMMC4Dataset(Dataset):
+ """Dataset for supervised fine-tuning.
+ This class is implemented by Ji Lin and Haotian Tang."""
+
+ def __init__(
+ self,
+ data_path: str,
+ image_folder: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+ image_following_text_only=False,
+ text_only=False,
+ ):
+ super().__init__()
+
+ import pickle
+
+ n_samples = []
+ # actually shards and stats info
+ n_shards = len(os.listdir(data_path)) // 2
+ # n_shards = 100
+ count_info_list = sorted([f for f in os.listdir(data_path) if f.endswith(".count")])[:n_shards]
+ n_samples = [int(open(os.path.join(data_path, f)).read().strip()) for f in count_info_list]
+
+ print("total MMC4 samples", sum(n_samples)) # 10,881,869
+
+ PROCESS_GROUP_MANAGER = get_pg_manager()
+ if PROCESS_GROUP_MANAGER is not None:
+ import torch.distributed as dist
+
+ sequence_parallel_size = training_args.seq_parallel_size
+ else:
+ sequence_parallel_size = 1
+ print("sequence_parallel_size", sequence_parallel_size)
+ rank = training_args.process_index // sequence_parallel_size # int(os.environ["RANK"])
+ world_size = training_args.world_size // sequence_parallel_size # int(os.environ["WORLD_SIZE"])
+ shared_size = n_shards // world_size
+
+ gpu_samples = [sum(n_samples[i * shared_size : (i + 1) * shared_size]) for i in range(world_size)]
+ self.n_samples = min(gpu_samples) * world_size # total size
+ self.idx_offset = rank * min(gpu_samples)
+ shard_start, shard_end = rank * shared_size, (rank + 1) * shared_size
+ print(f" * loading data from shard {shard_start}-{shard_end}")
+
+ shard_names = [d.replace(".count", ".pkl") for d in count_info_list]
+ shard_names = shard_names[shard_start:shard_end]
+
+ full_data_list = []
+ # now load data
+ for shard_name in shard_names:
+ # load shard
+ with open(os.path.join(data_path, shard_name), "rb") as f:
+ data_list = pickle.load(f)
+
+ full_data_list.extend(data_list)
+
+ print(f"* loaded totally {len(full_data_list)} samples")
+
+ self.data_list = full_data_list
+
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+ self.image_folder = image_folder
+
+ self.image_following_text_only = image_following_text_only
+ self.text_only = text_only
+
+ def __len__(self):
+ # return len(self.data_list)
+ return self.n_samples
+
+ @property
+ def modality_lengths(self):
+ # Estimate the number of tokens after tokenization, used for length-grouped sampling
+ length_list = []
+ for info in self.data_list:
+ num_images = min(6, len(info["image_info"]))
+ sentences = [info["text_list"][x["matched_text_index"]] for x in info["image_info"][:num_images]]
+ # The unit of cur_len is "words". We assume 1 word = 2 tokens.
+ cur_len = num_images * self.num_image_tokens // 2 + sum([len(x) for x in sentences])
+ length_list.append(cur_len)
+ return length_list
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ info = self.data_list[i - self.idx_offset]
+
+ sentences = info["text_list"]
+ # kentang-mit@: remove existing tokens in the sentences
+ for ix in range(len(sentences)):
+ # if this is an html tag, we still preserve its semantic meaning
+            sentences[ix] = sentences[ix].replace("<image>", "")
+ sim_matrix = info["similarity_matrix"] # we do not use this...
+
+ # convert images from base64 to PIL and filter based on image-text similarity
+ images, sentence_ixs = [], []
+ if not self.text_only:
+ for sample_image, sim_vec in zip(info["image_info"], sim_matrix):
+ image_base64 = sample_image["image_base64"]
+ rawbytes = base64.b64decode(image_base64)
+
+ sim_ix = sample_image["matched_text_index"]
+ # sim_ix = np.argmax(sim_vec)
+ # sim_score = sim_vec[sim_ix]
+
+ # filter to images >= 5KB
+ # if len(rawbytes) // 1000 <= 5:
+ # continue
+ # if sim_score < 0.24:
+ # continue
+ image = Image.open(io.BytesIO(rawbytes)).convert("RGB")
+
+ images.append(image)
+ sentence_ixs.append(sim_ix)
+
+ # constrain max num 6 images
+ max_num_images = 6
+ if len(images) > max_num_images:
+ images = images[:max_num_images]
+ sentence_ixs = sentence_ixs[:max_num_images]
+
+ # reorder images according to text insertion
+ images = [images[iii] for iii in np.argsort(sentence_ixs)]
+
+ # preprocess and tokenize text
+ for ix in sentence_ixs:
+            sentences[ix] = f"<image>\n{sentences[ix]}"
+
+ if self.image_following_text_only:
+ # use pad tokens to divide sentence pieces
+ text = self.tokenizer.pad_token.join(sentences)
+ else:
+ text = " ".join(sentences)
+ # whitespace cleanup
+            text = text.replace("<image> ", "<image>").replace(" <image>", "<image>")
+ text = f"{text}{self.tokenizer.eos_token}" # add eos token
+
+ if len(images) > 0:
+ if self.data_args.image_aspect_ratio == "dynamic_s2":
+ images, block_sizes = dynamic_s2_process_images_and_prompt(
+ images, text, self.data_args, self.image_folder
+ )
+ elif self.data_args.image_aspect_ratio == "dynamic":
+ images, text = dynamic_process_images_and_prompt(
+ images, text, self.data_args, self.image_folder, max_tiles=6
+ )
+ else:
+ images = torch.stack([process_image(image, self.data_args, self.image_folder) for image in images])
+
+ # the same size for all images, so we concat
+ # cur_token_len = (
+ # images[0].shape[-2] // self.multimodal_cfg["patch_size"]
+ # ) * (images[0].shape[-1] // self.multimodal_cfg["patch_size"])
+ # cur_token_len += self.multimodal_cfg["n_extra_patch"]
+ else:
+ images = None
+ # cur_token_len = 0
+
+ input_ids = tokenizer_image_token(
+ text,
+ self.tokenizer,
+ return_tensors="pt",
+ )
+
+ image_token_id = self.tokenizer.media_token_ids["image"]
+
+ # now check the case where the last token is image patch token
+ if input_ids[-1] == image_token_id: # need to remove one last image
+ last_non_im_patch_indices = torch.where(input_ids != image_token_id)[0][-1] + 1
+ input_ids = input_ids[:last_non_im_patch_indices]
+
+ n_im_patch = (input_ids == image_token_id).sum().item()
+
+ if self.data_args.image_aspect_ratio != "dynamic_s2":
+ images = images[:n_im_patch]
+ assert len(images) == n_im_patch, print(text, input_ids)
+ assert len(input_ids.shape) == 1, "Unexpected shape of 'input_ids' from MMC4."
+ input_ids = (
+ torch.concat([torch.tensor([self.tokenizer.bos_token_id]), input_ids])
+ if self.tokenizer.bos_token_id is not None and input_ids[0] != self.tokenizer.bos_token_id
+ else input_ids
+ )
+ targets = input_ids.clone()
+
+ if self.image_following_text_only: # keep only text after leading image token
+ # remove loss for any token before the first token
+ label_idx = 0
+ while label_idx < targets.shape[-1] and targets[label_idx] != image_token_id:
+ targets[label_idx] = IGNORE_INDEX
+ label_idx += 1
+
+ pad_token = self.tokenizer.convert_tokens_to_ids([self.tokenizer.pad_token])[0]
+
+ pad_token_idxs = torch.where(targets == pad_token)[0]
+ for pad_token_idx in pad_token_idxs:
+ token_idx = pad_token_idx + 1
+ while token_idx < targets.shape[-1] and targets[token_idx] != image_token_id:
+ targets[token_idx] = IGNORE_INDEX
+ token_idx += 1
+ # do not train on padding tokens
+ targets[targets == pad_token] = IGNORE_INDEX
+
+ # mask image tokens is unnecessary for llava-1.5
+ # targets[targets == IMAGE_TOKEN_INDEX] = IGNORE_INDEX
+ # print(input_ids.shape)
+
+ data_dict = dict(input_ids=input_ids, labels=targets, image=images)
+ if self.data_args.image_aspect_ratio == "dynamic_s2":
+ data_dict["block_sizes"] = block_sizes
+
+ return data_dict
+
+
+class LazyCoyoDataset(Dataset):
+ """Dataset for supervised fine-tuning.
+ This class is implemented by Ji Lin and Haotian Tang."""
+
+ num_image_tokens = 576
+
+ def __init__(
+ self,
+ data_path: str,
+ image_folder: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+ # kentang-mit@: balance the total number of tokens for Coyo and MMC4.
+ n_samples_per_idx=4,
+ ):
+ super().__init__()
+
+ import pickle
+
+ n_samples = []
+ # actually shards and stats info
+ n_shards = len(os.listdir(data_path)) // 2
+ # n_shards = 100
+ count_info_list = sorted([f for f in os.listdir(data_path) if f.endswith(".count")])[:n_shards]
+ n_samples = [int(open(os.path.join(data_path, f)).read().strip()) for f in count_info_list]
+
+ print("total COYO samples", sum(n_samples))
+
+ PROCESS_GROUP_MANAGER = get_pg_manager()
+ if PROCESS_GROUP_MANAGER is not None:
+ import torch.distributed as dist
+
+ sequence_parallel_size = training_args.seq_parallel_size
+ else:
+ sequence_parallel_size = 1
+ print("sequence_parallel_size", sequence_parallel_size)
+ rank = training_args.process_index // sequence_parallel_size # int(os.environ["RANK"])
+ world_size = training_args.world_size // sequence_parallel_size # int(os.environ["WORLD_SIZE"])
+ shared_size = n_shards // world_size
+
+ gpu_samples = [
+ sum(n_samples[i * shared_size : (i + 1) * shared_size]) // n_samples_per_idx for i in range(world_size)
+ ]
+ self.n_samples = min(gpu_samples) * world_size # total size
+ self.idx_offset = rank * min(gpu_samples)
+
+ shard_start, shard_end = rank * shared_size, (rank + 1) * shared_size
+ print(f" * loading data from shard {shard_start}-{shard_end}")
+
+ shard_names = [d.replace(".count", ".pkl") for d in count_info_list]
+ shard_names = shard_names[shard_start:shard_end]
+
+ full_data_list = []
+ # now load data
+ for shard_name in shard_names:
+ # load shard
+ with open(os.path.join(data_path, shard_name), "rb") as f:
+ shard_data = pickle.load(f)
+ random.seed(42)
+ if "mmc4" in data_path:
+ random.shuffle(shard_data) # shuffle for MMC4cap only
+ full_data_list.extend(shard_data)
+
+ print(f"* loaded totally {len(full_data_list)} samples")
+
+ # now pack the samples into groups
+ n_groups = len(full_data_list) // n_samples_per_idx
+ full_data_list = [
+ full_data_list[i : i + n_samples_per_idx] for i in range(0, len(full_data_list), n_samples_per_idx)
+ ]
+ if len(full_data_list[-1]) < n_samples_per_idx:
+ full_data_list = full_data_list[:-1]
+ assert len(full_data_list) == n_groups
+ print(f"split into {n_groups} groups")
+
+ self.data_list = full_data_list
+
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+ self.image_folder = image_folder
+
+ def __len__(self):
+ # return len(self.data_list)
+ return self.n_samples
+
+ @property
+ def modality_lengths(self):
+ # Estimate the number of tokens after tokenization, used for length-grouped sampling
+ length_list = []
+ for samples in self.data_list:
+ cur_len = sum([len(conv["text" if "text" in conv else "caption"].split()) for conv in samples])
+ # The unit of cur_len is "words". We assume 1 word = 2 tokens.
+ cur_len = cur_len + len(samples) * self.num_image_tokens // 2
+ length_list.append(cur_len)
+ return length_list
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ CONCAT_SAMPLES = False
+ info_list = self.data_list[i - self.idx_offset]
+
+ text_list = []
+ image_list = []
+
+ for sample in info_list:
+ caption_key = (
+ "text" if "text" in sample else "caption"
+ ) # kentang-mit@: remove existing tokens in the sentences
+ # kentang-mit@: remove existing token.
+ # if this is an html tag, we still preserve its semantic meaning
+            sample[caption_key] = sample[caption_key].replace("<image>", "")
+ text_list.append(DEFAULT_IMAGE_TOKEN + "\n" + sample[caption_key] + self.tokenizer.eos_token)
+ if "image" in sample:
+ image_base64 = sample["image"]
+ rawbytes = base64.b64decode(image_base64)
+ else:
+ rawbytes = sample["rawbytes"]
+ image = Image.open(io.BytesIO(rawbytes)).convert("RGB")
+ image_list.append(image)
+
+ image_list = torch.stack([process_image(image, self.data_args, self.image_folder) for image in image_list])
+
+ if CONCAT_SAMPLES:
+ # into capcap...
+ text_list = "".join(text_list)
+
+ input_ids = self.tokenizer(
+ text_list,
+ return_tensors="pt",
+ padding="longest",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids # 4, seq_len
+
+ input_ids = input_ids[0]
+
+ else:
+ input_ids = [
+ tokenizer_image_token(
+ prompt,
+ self.tokenizer,
+ return_tensors="pt",
+ )
+ for prompt in text_list
+ ]
+ # print([x.shape[0] for x in input_ids], [len(x.split()) for x in text_list], [len(re.findall(r"]*>", x)) for x in text_list])
+
+ # input_ids = torch.nn.utils.rnn.pad_sequence(
+ # input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
+ # )
+
+ targets = copy.deepcopy(input_ids)
+ for i in range(len(targets)):
+ targets[i][targets[i] == self.tokenizer.pad_token_id] = IGNORE_INDEX
+
+ return dict(input_ids=input_ids, labels=targets, image=image_list)
+
+
+class LazyWDSDataset(Dataset):
+ """Dataset for supervised fine-tuning.
+ This class is implemented by Ji Lin and Ligeng Zhu."""
+
+ def __init__(
+ self,
+ data_path: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments,
+ image_folder: str,
+ training_args: TrainingArguments,
+ ):
+ super().__init__()
+ n_samples = []
+ n_shards = len(os.listdir(data_path)) // 3
+ for shard in range(n_shards):
+ with open(os.path.join(data_path, f"{shard:05d}_stats.json")) as f:
+ info = json.load(f)
+ n_samples.append(info["successes"])
+
+ # print(f"[DEBUG] {data_path} total samples", sum(n_samples)) # 10,881,869
+
+ PROCESS_GROUP_MANAGER = get_pg_manager()
+ if PROCESS_GROUP_MANAGER is not None:
+ import torch.distributed as dist
+
+ sequence_parallel_size = training_args.seq_parallel_size
+ else:
+ sequence_parallel_size = 1
+ print("sequence_parallel_size", sequence_parallel_size)
+ rank = training_args.process_index // sequence_parallel_size # int(os.environ["RANK"])
+ world_size = training_args.world_size // sequence_parallel_size # int(os.environ["WORLD_SIZE"])
+ shared_size = n_shards // world_size
+ print("rank", rank, "world_size", world_size, "shared_size", shared_size)
+ gpu_samples = [sum(n_samples[i * shared_size : (i + 1) * shared_size]) for i in range(world_size)]
+ self.n_samples = min(gpu_samples) * world_size # total size
+ self.idx_offset = rank * min(gpu_samples)
+ shard_start, shard_end = rank * shared_size, (rank + 1) * shared_size
+ print(f" * loading data from shard {shard_start}-{shard_end}")
+
+ tar_list = [f"{shard_idx:05d}.tar" for shard_idx in range(shard_start, shard_end)]
+
+ self.data_list = []
+ t1 = time.time()
+ for tar in tar_list:
+ tmp_path = f"/tmp/ccs{tar}"
+ tar_path = os.path.join(data_path, tar)
+
+ if PROCESS_GROUP_MANAGER is not None:
+ dist.barrier()
+ if PROCESS_GROUP_MANAGER.sp_rank == 0:
+ os.makedirs(tmp_path, exist_ok=True)
+ os.system(f"tar -xkf {tar_path} -C {tmp_path}")
+ dist.barrier()
+ else:
+ os.makedirs(tmp_path, exist_ok=True)
+ os.system(f"tar -xkf {tar_path} -C {tmp_path}")
+
+ txt_list = [f for f in os.listdir(tmp_path) if f.endswith(".txt")]
+
+ for txt in txt_list:
+ caption = open(os.path.join(tmp_path, txt)).read().strip()
+ image_path = os.path.join(tmp_path, txt.split(".")[0] + ".jpg")
+ self.data_list.append({"caption": caption, "image": image_path})
+ t2 = time.time()
+ print(f"Loading done. Total time: {t2 - t1:.2f} seconds")
+
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+ self.image_folder = image_folder
+
+ def __len__(self):
+ return self.n_samples
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+
+ # print("i", i, "idx_offset", self.idx_offset, "len", len(self.data_list))
+ info = self.data_list[i - self.idx_offset]
+ caption, image_path = info["caption"], info["image"]
+
+ rand_prompt = "<image>\n"
+ sources = [
+ {
+ "image": image_path,
+ "conversations": [
+ {"from": "human", "value": rand_prompt},
+ {"from": "gpt", "value": caption},
+ ],
+ }
+ ]
+
+ # one example of sources
+ # [{'id': 'GCC_train_001738742', 'image': 'GCC_train_001738742.jpg', 'conversations': [{'from': 'human', 'value': 'Provide a brief description of the given image.\n'}, {'from': 'gpt', 'value': 'a sketch of an ostrich'}]}]
+ if "image" in sources[0]:
+ image = process_image(sources[0]["image"], self.data_args, self.image_folder)
+ image = torch.unsqueeze(image, dim=0)
+ # now random pick some context samples for training
+ if hasattr(self.data_args, "num_shots"):
+ if self.data_args.num_shots > 0:
+ raise NotImplementedError
+ else:
+ raise NotImplementedError
+
+ data_dict = preprocess([sources[0]["conversations"]], self.tokenizer, has_image=True)
+
+ if isinstance(i, int):
+ data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
+
+ # image exist in the data
+ if image is not None:
+ data_dict["image"] = image
+ else:
+ raise NotImplementedError
+
+ return data_dict
+
+
+class LazyCCSWebDataset(Dataset):
+ """Dataset for supervised fine-tuning.
+ This class is implemented by Ligeng Zhu."""
+
+ def __init__(
+ self,
+ data_path: str,
+ image_folder: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+ ):
+ super().__init__()
+ t1 = time.time()
+
+ from llava.data.simple_vila_webdataset import VILAWebDataset
+
+ print("[DEBUG] ", osp.abspath(data_path))
+ self.dataset = VILAWebDataset(data_path=osp.abspath(data_path))
+
+ t2 = time.time()
+ print(f"Loading done. Total time: {t2 - t1:.2f} seconds")
+
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ # info = self.data_list[i - self.idx_offset]
+ # caption, image_path = info["caption"], info["image"]
+ info = self.dataset[i]
+ if ".jpg" in info:
+ caption, image_path = info[".txt"], info[".jpg"]
+ elif ".png" in info:
+ caption, image_path = info[".txt"], info[".png"]
+ elif ".webp" in info:
+ caption, image_path = info[".txt"], info[".webp"]
+ elif ".bmp" in info:
+ caption, image_path = info[".txt"], info[".bmp"]
+ elif ".tiff" in info:
+ caption, image_path = info[".txt"], info[".tiff"]
+ else:
+ print(info.keys())
+ print(info)
+ raise KeyError
+
+ caption = caption.replace("<image>", "<IMAGE>")
+ if isinstance(image_path, io.BytesIO):
+ image_path = Image.open(image_path).convert("RGB")
+
+ if not isinstance(image_path, PIL.Image.Image):
+ print(image_path)
+ print(info.keys())
+ print(type(image_path))
+ raise NotImplementedError
+
+ rand_prompt = "<image>\n"
+ sources = [
+ {
+ "image": image_path,
+ "conversations": [
+ {"from": "human", "value": rand_prompt},
+ {"from": "gpt", "value": caption},
+ ],
+ }
+ ]
+
+ # one example of sources
+ # [{'id': 'GCC_train_001738742', 'image': 'GCC_train_001738742.jpg', 'conversations': [{'from': 'human', 'value': 'Provide a brief description of the given image.\n'}, {'from': 'gpt', 'value': 'a sketch of an ostrich'}]}]
+ if "image" in sources[0]:
+ image = process_image(sources[0]["image"], self.data_args, image_folder=None)
+ image = torch.unsqueeze(image, dim=0)
+ # now random pick some context samples for training
+ if hasattr(self.data_args, "num_shots"):
+ if self.data_args.num_shots > 0:
+ raise NotImplementedError
+ else:
+ raise NotImplementedError
+
+ data_dict = preprocess([sources[0]["conversations"]], self.tokenizer, has_image=True)
+
+ if isinstance(i, int):
+ data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
+
+ # image exist in the data
+ if image is not None:
+ data_dict["image"] = image
+ else:
+ raise NotImplementedError
+
+ return data_dict
+
+
+from functools import lru_cache
+
+
+@lru_cache(maxsize=16)
+def lru_json_load(fpath):
+ with open(fpath) as fp:
+ return json.load(fp)
+
+
+class LazyCoyoWebDataset(Dataset):
+ """Dataset for supervised fine-tuning.
+ This class is implemented by Ligeng Zhu."""
+
+ num_image_tokens = 576
+
+ def __init__(
+ self,
+ data_path: str,
+ image_folder: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+ # kentang-mit@: balance the total number of tokens for Coyo and MMC4.
+ n_samples_per_idx=4,
+ ):
+ super().__init__()
+
+ from llava.data.simple_vila_webdataset import VILAWebDataset
+
+ print("[DEBUG] ", osp.abspath(data_path))
+ self.dataset = VILAWebDataset(data_path=osp.abspath(data_path), meta_path=data_args.meta_path)
+
+ if data_args.start_idx >= 0 and data_args.end_idx >= 0:
+ # Ligeng: support slicing for ablate different subsets.
+ total = len(self.dataset)
+ start_idx = int(total * data_args.start_idx)
+ end_idx = int(total * data_args.end_idx)
+ print(f"loading subset from {start_idx} to {end_idx}, total {total}")
+ self.dataset = torch.utils.data.Subset(self.dataset, range(start_idx, end_idx))
+
+ # For caption choice,
+ # if None: use original caption
+ # if a folder path: use specified caption to override original one (choice1)
+ # if a folder path: use specified caption and concat with original one (choice2)
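+ # (illustrative) an override directory is expected to contain one <tar_name>.json per shard,
+ # mapping each sample url to {"output": "<new caption>"}, as read by the lookup in __getitem__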
+ self.caption_choice = None
+ self.caption_choice_2 = None
+ self.data_path = data_path
+
+ if data_args.caption_choice is not None:
+ self.caption_choice = data_args.caption_choice
+ print("[recap] Override coyo caption using ", self.caption_choice)
+
+ if data_args.caption_choice_2 is not None:
+ self.caption_choice_2 = data_args.caption_choice_2
+ print("[recapv2] Override coyo caption using ", self.caption_choice_2)
+
+ print("total samples", len(self.dataset))
+ PROCESS_GROUP_MANAGER = get_pg_manager()
+ if PROCESS_GROUP_MANAGER is not None:
+ import torch.distributed as dist
+
+ sequence_parallel_size = training_args.seq_parallel_size
+ sequence_parallel_rank = PROCESS_GROUP_MANAGER.sp_rank
+ else:
+ sequence_parallel_size = 1
+ print("sequence_parallel_size", sequence_parallel_size)
+ rank = (
+ training_args.process_index // sequence_parallel_size if "RANK" in os.environ else 2
+ ) # int(os.environ["RANK"])
+ world_size = (
+ training_args.world_size // sequence_parallel_size if "WORLD_SIZE" in os.environ else 32
+ ) # int(os.environ["WORLD_SIZE"])
+ print(
+ "rank",
+ rank,
+ "world_size",
+ world_size,
+ )
+
+ self.n_samples_per_idx = n_samples_per_idx
+ # self.n_samples = len(self.dataset) // n_samples_per_idx
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+
+ def __len__(self):
+ return len(self.dataset) // self.n_samples_per_idx
+
+ @property
+ def modality_lengths(self):
+ # Estimate the number of tokens after tokenization, used for length-grouped sampling
+ length_list = []
+ for samples in self.data_list:
+ cur_len = sum([len(conv["text" if "text" in conv else "caption"].split()) for conv in samples])
+ # The unit of cur_len is "words". We assume 1 word = 2 tokens.
+ cur_len = cur_len + len(samples) * self.num_image_tokens // 2
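+ # e.g. (illustrative numbers) 4 captions of 20 words each: 4 * 20 + 4 * 576 // 2 = 1232 in word units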
+ length_list.append(cur_len)
+ return length_list
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ CONCAT_SAMPLES = False
+ # info_list = self.dataset[i - self.idx_offset]
+
+ begin_idx, end_idx = (
+ i * self.n_samples_per_idx,
+ (i + 1) * self.n_samples_per_idx,
+ )
+ end_idx = min(end_idx, len(self.dataset))
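+ # each index packs n_samples_per_idx consecutive webdataset records into a single training sample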
+
+ text_list = []
+ image_list = []
+
+ for idx in range(begin_idx, end_idx):
+ info = self.dataset[idx]
+ if ".jpg" in info:
+ caption, image_path = info[".txt"], info[".jpg"]
+ elif ".png" in info:
+ caption, image_path = info[".txt"], info[".png"]
+ elif ".webp" in info:
+ caption, image_path = info[".txt"], info[".webp"]
+ elif ".bmp" in info:
+ caption, image_path = info[".txt"], info[".bmp"]
+ elif ".tiff" in info:
+ caption, image_path = info[".txt"], info[".tiff"]
+ else:
+ print(info.keys())
+ print(info)
+ raise KeyError
+
+ if self.caption_choice is not None:
+ # load new captions
+ shard = info["__shard__"]
+ url = info[".json"]["url"]
+ tar_name = osp.relpath(osp.realpath(shard), osp.realpath(self.data_path))
+ # tar_name = osp.dirname(shard)
+ shard_json_path = osp.join(self.caption_choice, tar_name + ".json")
+ try:
+ shard_json = lru_json_load(shard_json_path)
+ try:
+ caption = shard_json[url]["output"]
+ except KeyError:
+ print(f"{url} not in caption. fallback to original caption temporarially")
+ except:
+ print(f"shard_json_path {shard_json_path} not found. fallback to original caption temporarially")
+ caption = caption.replace("", "")
+ text_list.append(DEFAULT_IMAGE_TOKEN + caption + self.tokenizer.eos_token)
+
+ if isinstance(image_path, io.BytesIO):
+ image_path = Image.open(image_path).convert("RGB")
+
+ if not isinstance(image_path, PIL.Image.Image):
+ print(image_path)
+ print(info.keys())
+ print(type(image_path))
+ raise NotImplementedError
+
+ image_list.append(image_path)
+
+ # image_list = torch.stack([process_image(image, self.data_args, image_folder=None) for image in image_list])
+ # NOTE(fix by ligeng)
+ # now image_list should return a list of image tensor where each has a dimension of (1, c, h, w)
+ image_list = [process_image(image, self.data_args, image_folder=None).unsqueeze(0) for image in image_list]
+
+ if CONCAT_SAMPLES:
+ # into capcap...
+ text_list = "".join(text_list)
+
+ input_ids = self.tokenizer(
+ text_list,
+ return_tensors="pt",
+ padding="longest",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids # 4, seq_len
+
+ input_ids = input_ids[0]
+ else:
+ input_ids = [
+ tokenizer_image_token(
+ prompt,
+ self.tokenizer,
+ return_tensors="pt",
+ )
+ for prompt in text_list
+ ]
+ input_ids = [
+ (
+ torch.concat([torch.tensor([self.tokenizer.bos_token_id]), input_ids_i])
+ if input_ids_i[0] != self.tokenizer.bos_token_id
+ else input_ids_i
+ )
+ for input_ids_i in input_ids
+ ]
+
+ targets = copy.deepcopy(input_ids)
+ for i in range(len(targets)):
+ targets[i][targets[i] == self.tokenizer.pad_token_id] = IGNORE_INDEX
+
+ return dict(input_ids=input_ids, labels=targets, image=image_list)
+
+
+class LazyVideoWebDataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(
+ self,
+ data_path: str,
+ image_folder: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+ # cache_path: str,
+ # n_samples_per_idx=4,
+ ):
+ super().__init__()
+
+ # from llava.data.simple_video_dataset import SimpleVideoDataset
+
+ from llava.data.simple_vila_webdataset import VILAWebDataset
+
+ print("[DEBUG] ", osp.abspath(data_path))
+ self.dataset = VILAWebDataset(
+ data_path=osp.abspath(data_path),
+ meta_path=f"{osp.abspath(data_path)}/wids-meta.json",
+ # cache_dir=cache_path,
+ )
+
+ # None: use original caption
+ # Folder path: use captions from that folder to override the original ones
+ self.caption_choice = None
+ self.data_path = data_path
+
+ if data_args.caption_choice is not None:
+ self.caption_choice = data_args.caption_choice
+ print("[recap] Override LazyVideo caption using ", self.caption_choice)
+
+ print("total samples", len(self.dataset))
+ # InternVid: TODO
+ PROCESS_GROUP_MANAGER = get_pg_manager()
+ if PROCESS_GROUP_MANAGER is not None:
+ import torch.distributed as dist
+
+ sequence_parallel_size = training_args.seq_parallel_size
+ sequence_parallel_rank = PROCESS_GROUP_MANAGER.sp_rank
+ else:
+ sequence_parallel_size = 1
+ print("sequence_parallel_size", sequence_parallel_size)
+ rank = (
+ training_args.process_index // sequence_parallel_size if "RANK" in os.environ else 2
+ ) # int(os.environ["RANK"])
+ world_size = (
+ training_args.world_size // sequence_parallel_size if "WORLD_SIZE" in os.environ else 32
+ ) # int(os.environ["WORLD_SIZE"])
+ print(
+ "rank",
+ rank,
+ "world_size",
+ world_size,
+ )
+ self.rank = rank
+ # rank = int(os.environ["RANK"]) if "RANK" in os.environ else 2
+ # world_size = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 32
+
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+
+ self.missing_uids = set()
+
+ def __len__(self):
+ return len(self.dataset)
+
+ @property
+ def modality_lengths(self):
+ # Estimate the number of tokens after tokenization, used for length-grouped sampling
+ length_list = []
+ for samples in self.data_list:
+ cur_len = sum([len(conv["text" if "text" in conv else "caption"].split()) for conv in samples])
+ # The unit of cur_len is "words". We assume 1 word = 2 tokens.
+ cur_len = cur_len + len(samples) * self.num_image_tokens // 2
+ length_list.append(cur_len)
+ return length_list
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ ADD_TEXT_PROMPT = False
+ num_video_frames = self.data_args.num_video_frames if hasattr(self.data_args, "num_video_frames") else 8
+ loader_fps = self.data_args.fps if hasattr(self.data_args, "fps") else 0.0
+
+ info = self.dataset[i]
+
+ caption = ""
+ # print(info)
+ if ".mp4" in info:
+ caption, video_path = info[".txt"], info[".mp4"]
+ else:
+ video_path = None
+ caption = "Empty video."
+
+ images, frames_loaded, _ = LazySupervisedDataset._load_video(
+ video_path, num_video_frames, loader_fps, self.data_args
+ )
+
+ if frames_loaded == 0:
+ caption = "Empty video."
+
+ if self.caption_choice is not None:
+ shard = info["__shard__"]
+ uuid = osp.join(info["__shard__"], info["__key__"])
+ url = info["__key__"]
+ tar_name = osp.basename(info["__shard__"])
+
+ try:
+ shard_json_path = osp.join(self.caption_choice, tar_name.replace(".tar", ".json"))
+ shard_json = lru_json_load(shard_json_path)
+ caption = shard_json[url]["summary"]["output"]
+ except (KeyError, FileNotFoundError, json.decoder.JSONDecodeError):
+ if uuid not in self.missing_uids:
+ print("override caption not found for ", uuid)
+ self.missing_uids.add(uuid)
+
+ # print(f"[DEBUG {uuid}]", caption)
+
+ frames_loaded_successfully = len(images)
+ if caption is None:
+ caption = ""
+ prompt = "<image>\n" * frames_loaded_successfully + caption
+ image_tensor = torch.stack([process_image(image, self.data_args, None) for image in images])
+
+ input_ids = tokenizer_image_token(
+ prompt,
+ self.tokenizer,
+ return_tensors="pt",
+ )
+ targets = copy.deepcopy(input_ids)
+ data_dict = dict(input_ids=input_ids, labels=targets, image=image_tensor)
+
+ return data_dict
+
+
+class DataCollatorForSupervisedDatasetSeqParallel:
+ """Collate examples for supervised fine-tuning.
+ This class is originally implemented by the LLaVA team and
+ modified by Haotian Tang."""
+
+ def __init__(
+ self,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+ sp_degree: int,
+ sp_rank: int,
+ ring_degree: int,
+ ring_type: str,
+ ):
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+ self.training_args = training_args
+ self.sp_degree = sp_degree
+ self.sp_rank = sp_rank
+ self.ring_degree = ring_degree
+ self.ring_type = ring_type
+
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+ input_ids, labels, images = [], [], []
+ image_token_id = self.tokenizer.media_token_ids["image"]
+ video_token_id = self.tokenizer.media_token_ids["video"]
+
+ for instance in instances:
+ if not isinstance(instance["input_ids"], list):
+ input_ids.append(instance["input_ids"])
+ else:
+ input_ids += instance["input_ids"]
+ if not isinstance(instance["labels"], list):
+ labels.append(instance["labels"])
+ else:
+ labels += instance["labels"]
+ # Note (kentang-mit@: we do not directly push tensors to
+ # images, but list of tensors.
+ if "video" in instance:
+ instance["image"] = torch.cat(instance["video"])
+ video_id_pos = torch.where(input_ids[-1] == video_token_id)[0][0]
+ replace_ids = torch.Tensor(
+ ([image_token_id] + self.tokenizer.encode("\n")) * instance["image"].shape[0],
+ device=input_ids[-1].device,
+ )
+ input_ids[-1] = torch.cat(
+ [input_ids[-1][:video_id_pos], replace_ids, input_ids[-1][video_id_pos + 1 :]]
+ ).to(input_ids[-1].dtype)
+ labels[-1] = torch.cat(
+ [
+ labels[-1][:video_id_pos],
+ torch.Tensor([IGNORE_INDEX] * instance["image"].shape[0] * 2),
+ labels[-1][video_id_pos + 1 :],
+ ]
+ ).to(labels[-1].dtype)
+ instance.pop("video")
+
+ if "image" in instance:
+ cur_image = instance["image"]
+ assert len(cur_image.shape) == 4
+ # n_images, 3, size, size
+ if cur_image.shape[0] == 0:
+ warnings.warn("loaded one sample without images.")
+ if not isinstance(instance["input_ids"], list):
+ # datasets other than coyo, not packing >1 samples together
+ images.append(cur_image)
+ else:
+ # coyo-like datasets
+ images.extend(cur_image.chunk(cur_image.size(0), 0))
+ else:
+ warnings.warn("loaded one sample without images.")
+ images.append([])
+ # kentang-mit@: we need to make sure these two lists have
+ # the same length. We will use input_ids to filter out images corresponding
+ # to truncated tokens later.
+
+ max_num_images = max([len(_images) for _images in images])
+ for _images, _input_ids in zip(images, input_ids):
+ assert (
+ len(_images) == (_input_ids == image_token_id).sum().item()
+ ), f"Number mismatch between images and placeholder image tokens in 'len(_images) == (_input_ids == image_token_id).sum().item()'.\
+ Expect to have {len(_images)} images but only found {(_input_ids == image_token_id).sum().item()} images in tokens. \
+ Error input_ids: {_input_ids} {self.tokenizer.decode([x if x != -200 else 200 for x in _input_ids])}"
+
+ NUM_TOKENS_PER_IMAGE = self.data_args.num_image_tokens
+ if hasattr(self.data_args.image_processor, "crop_size"):
+ crop_size = self.data_args.image_processor.crop_size
+ else:
+ crop_size = self.data_args.image_processor.size
+
+ # Init the padding sample
+ seq_id = 0
+ while seq_id < len(input_ids):
+ # Build one dummy padding sample based on the first sequence; the loop exits after a single iteration
+ dummy_image = torch.ones((1, 3, crop_size["height"], crop_size["width"]), device=input_ids[seq_id].device)
+ # dummy input_ids include one bos, one image token, and one eos
+ dummy_input_ids = torch.zeros_like(input_ids[seq_id][:3])
+ dummy_input_ids[0] = self.tokenizer.bos_token_id
+ dummy_input_ids[1] = image_token_id
+ dummy_input_ids[2] = self.tokenizer.eos_token_id
+ dummy_labels = copy.deepcopy(dummy_input_ids)
+ dummy_labels[:2] = IGNORE_INDEX
+ dummy_seqlen = NUM_TOKENS_PER_IMAGE + 2 # TODO: Check the hard coding of 2
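+ # the dummy sequence is <bos> <image> <eos>; once the image placeholder expands to NUM_TOKENS_PER_IMAGE tokens, its effective length is NUM_TOKENS_PER_IMAGE + 2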
+ dummy_position_ids = torch.arange(start=0, end=dummy_seqlen, dtype=torch.int32)
+ break
+
+ # Sort with the real length of the sequence
+ combined = sorted(
+ zip(input_ids, labels, images),
+ key=lambda x: len(x[2]) * (NUM_TOKENS_PER_IMAGE - 1) + x[0].size(-1),
+ reverse=True, # Start Packing from the sequence with most images.
+ )
+ sorted_ids, sorted_labels, sorted_images = zip(*combined)
+ sorted_ids, sorted_labels, sorted_images = list(sorted_ids), list(sorted_labels), list(sorted_images)
+ max_seq_length = self.tokenizer.model_max_length # len(sorted_ids[0])
+ max_sample_len = 0
+
+ batches = []
+ label_batches = []
+ position_ids = []
+ batch_images = []
+ seqlens_in_batch = []
+
+ i = 0
+ while i < len(sorted_ids):
+ current_batch = torch.tensor([], dtype=torch.int32)
+ current_label_batch = torch.tensor([], dtype=torch.int32)
+ current_position_ids = torch.tensor([], dtype=torch.int32)
+ current_batch_images = []
+ current_num_images = 0
+ current_len = 0
+ current_num_samples = 0
+
+ # Pack a few samples into one sample
+ while i < len(sorted_ids):
+ num_images = (sorted_ids[i] == image_token_id).sum().item()
+ num_image_tokens_added = num_images * (NUM_TOKENS_PER_IMAGE - 1)
+ num_incoming_tokens = sorted_ids[i].size(-1) + num_image_tokens_added
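+ # each image placeholder counts as one token in sorted_ids but expands to NUM_TOKENS_PER_IMAGE tokens in the model, hence the extra NUM_TOKENS_PER_IMAGE - 1 per image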
+
+ # Handle RingAttn_Varlen which requires `seqlens_in_batch` should be divisible by `ring_degree`
+ if self.ring_degree > 1:
+ RING_PAD_TOKEN_INDEX = 2
+ if self.ring_type == "ring_varlen":
+ if num_incoming_tokens % self.sp_degree != 0:
+ pad_len = self.sp_degree - num_incoming_tokens % self.sp_degree
+ num_incoming_tokens += pad_len
+ # pad `input_ids`
+ pad_tensor = torch.full(
+ (pad_len,), RING_PAD_TOKEN_INDEX, dtype=sorted_ids[i].dtype, device=sorted_ids[i].device
+ )
+ sorted_ids[i] = torch.cat([sorted_ids[i], pad_tensor])
+
+ # pad `label`
+ pad_label_tensor = torch.full(
+ (pad_len,), IGNORE_INDEX, dtype=sorted_labels[i].dtype, device=sorted_labels[i].device
+ )
+ sorted_labels[i] = torch.cat([sorted_labels[i], pad_label_tensor])
+ elif self.ring_type == "zigzag_ring_varlen":
+ self.zigzag_sp_degree = self.sp_degree * 2
+ if num_incoming_tokens % self.zigzag_sp_degree != 0:
+ pad_len = self.zigzag_sp_degree - num_incoming_tokens % self.zigzag_sp_degree
+ num_incoming_tokens += pad_len
+ # pad `input_ids`
+ pad_tensor = torch.full(
+ (pad_len,), RING_PAD_TOKEN_INDEX, dtype=sorted_ids[i].dtype, device=sorted_ids[i].device
+ )
+ sorted_ids[i] = torch.cat([sorted_ids[i], pad_tensor])
+
+ # pad `label`
+ pad_label_tensor = torch.full(
+ (pad_len,), IGNORE_INDEX, dtype=sorted_labels[i].dtype, device=sorted_labels[i].device
+ )
+ sorted_labels[i] = torch.cat([sorted_labels[i], pad_label_tensor])
+ else:
+ raise ValueError(f"Invalid ring_type: {self.ring_type}")
+
+ if num_incoming_tokens > max_seq_length:
+ print(
+ f"Warning: Skipping one packed sample with {num_incoming_tokens} tokens,\
+ please consider increase max seq len {max_seq_length}."
+ )
+ i += 1
+ continue
+
+ if (
+ (current_num_images == 0)
+ or (current_num_images < self.sp_degree)
+ or (
+ (current_num_images + num_images <= max_num_images)
+ and (current_len + num_incoming_tokens <= max_sample_len)
+ )
+ ) and (current_len + num_incoming_tokens <= max_seq_length):
+ current_num_images += num_images
+ current_len += num_incoming_tokens
+ current_num_samples += 1
+ current_position_ids = torch.cat(
+ (current_position_ids, torch.arange(start=0, end=num_incoming_tokens)), dim=0
+ )
+ current_batch = torch.cat((current_batch, sorted_ids[i]), dim=0)
+ sorted_labels[i][0] = IGNORE_INDEX
+ current_label_batch = torch.cat((current_label_batch, sorted_labels[i]), dim=0)
+ seqlens_in_batch.append(num_incoming_tokens)
+ current_batch_images.extend(sorted_images[i])
+ i += 1
+ assert current_num_images == len(current_batch_images)
+ else:
+ break
+
+ # Pad the sample with the dummy image sample if there are not enough images
+ MAX_RETRY = self.sp_degree
+ num_retry = 0
+ while current_num_images < self.sp_degree and current_len < max_seq_length and num_retry <= MAX_RETRY:
+ current_num_images += dummy_image.size(0)
+ current_len += dummy_seqlen
+ current_num_samples += 1
+ current_position_ids = torch.cat((current_position_ids, dummy_position_ids), dim=0)
+ current_batch = torch.cat((current_batch, dummy_input_ids), dim=0)
+ current_label_batch = torch.cat((current_label_batch, dummy_labels), dim=0)
+ seqlens_in_batch.append(dummy_seqlen)
+ current_batch_images.extend(dummy_image)
+ # We pad from left side to ensure correct grad flow
+ # current_batch = torch.cat((dummy_input_ids, current_batch), dim=0)
+ # current_label_batch = torch.cat((dummy_labels, current_label_batch), dim=0)
+ # seqlens_in_batch.insert(0, dummy_seqlen)
+ # current_batch_images = torch.cat((dummy_image, current_batch_images), dim=0)
+ num_retry += 1
+
+ # Drop the samples that do not have enough images
+ if current_num_images < self.sp_degree:
+ print(f"Warning: Skipping one packed sample with {current_num_images} images")
+ seqlens_in_batch = seqlens_in_batch[:-current_num_samples]
+ continue
+
+ max_sample_len = max(max_sample_len, current_len)
+ batches.append(current_batch)
+ label_batches.append(current_label_batch)
+ position_ids.append(current_position_ids)
+ batch_images.append(current_batch_images)
+
+ try:
+ assert current_num_images == len(torch.where(current_batch == image_token_id)[0].tolist())
+ except AssertionError:
+ print(f"Error num_images on {self.sp_rank}", current_num_images)
+ print("current_batch", current_batch)
+ print(
+ f"Error len(torch.where(batches[i] == image_token_id)[0].tolist() on {self.sp_rank}:",
+ len(torch.where(current_batch == image_token_id)[0].tolist()),
+ )
+ print(f"Error len(current_batch_images) on {self.sp_rank}:", len(current_batch_images))
+ raise AssertionError
+
+ # Split for sequence parallelism
+ for i in range(len(batches)):
+ image_token_indices = torch.where(batches[i] == image_token_id)[0].tolist()
+ image_ids = torch.arange(0, len(image_token_indices), dtype=torch.int32)
+ batches[i] = extract_local_input_ids(
+ batches[i], image_token_indices, self.sp_rank, self.sp_degree, self.tokenizer.bos_token_id
+ )
+ label_batches[i] = extract_local_input_ids(
+ label_batches[i], image_token_indices, self.sp_rank, self.sp_degree, self.tokenizer.bos_token_id
+ )
+ batch_images[i] = torch.concat(
+ extract_local_from_list(batch_images[i], self.sp_rank, self.sp_degree), dim=0
+ )
+ H, W = batch_images[i].size(-2), batch_images[i].size(-1)
+ batch_images[i] = batch_images[i].reshape(-1, 3, W, H)
+ num_images = len(batch_images[i])
+
+ try:
+ assert num_images == len(torch.where(batches[i] == image_token_id)[0].tolist())
+ except AssertionError:
+ print(f"Error num_images on {self.sp_rank}", num_images)
+ print("batches[i]", batches[i])
+ print(
+ f"Error len(torch.where(batches[i] == image_token_id)[0].tolist() on {self.sp_rank}:",
+ len(torch.where(batches[i] == image_token_id)[0].tolist()),
+ )
+ print(f"Error batch_images[i] on {self.sp_rank}:", batch_images[i].shape)
+ raise AssertionError
+ position_ids[i] = extract_local_position_ids(
+ position_ids[i], image_token_indices, image_ids, self.sp_rank, self.sp_degree, NUM_TOKENS_PER_IMAGE - 1
+ )
+
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ batches, batch_first=True, padding_value=self.tokenizer.pad_token_id
+ )
+ labels = torch.nn.utils.rnn.pad_sequence(label_batches, batch_first=True, padding_value=IGNORE_INDEX)
+ seqlens_in_batch = [torch.tensor(x) for x in seqlens_in_batch]
+ seqlens_in_batch = torch.stack(seqlens_in_batch, axis=0)
+ seqlens_in_batch = seqlens_in_batch.flatten()
+ position_ids = torch.nn.utils.rnn.pad_sequence(position_ids, batch_first=True, padding_value=-1)
+
+ if batch_images:
+ batch_images = [torch.unbind(images) for images in batch_images]
+ flat_batch_images = [item for sublist in batch_images for item in sublist]
+ else:
+ flat_batch_images = None
+ batch = dict(
+ input_ids=input_ids,
+ labels=labels,
+ # notice that we inject attention mask here
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ seqlens_in_batch=seqlens_in_batch,
+ media={"image": flat_batch_images},
+ media_config={"image": {}},
+ position_ids=position_ids,
+ )
+ return batch
+
+
+def make_supervised_data_module(
+ tokenizer: PreTrainedTokenizer,
+ data_args: DataArguments,
+ training_args: TrainingArguments,
+) -> Dict:
+ """Make dataset and collator for supervised fine-tuning.
+ This function is originally implemented by the LLaVA team and
+ modified by Jason Lu, Haotian Tang and Ligeng Zhu."""
+ datasets_mixture.register_datasets_mixtures()
+
+ from .builder import build_dataset
+
+ train_dataset = build_dataset(data_args.data_mixture, data_args, training_args, tokenizer)
+ training_args.sample_lens = [len(d) for d in train_dataset.datasets]
+
+ PROCESS_GROUP_MANAGER = get_pg_manager()
+ if PROCESS_GROUP_MANAGER is None:
+ data_collator = DataCollator(tokenizer=tokenizer)
+ else:
+ sp_degree = training_args.seq_parallel_size
+ sp_rank = PROCESS_GROUP_MANAGER.sp_rank
+ ring_degree = PROCESS_GROUP_MANAGER.ring_degree
+ ring_type = PROCESS_GROUP_MANAGER.ring_type
+ data_collator = DataCollatorForSupervisedDatasetSeqParallel(
+ tokenizer=tokenizer,
+ data_args=data_args,
+ training_args=training_args,
+ sp_degree=sp_degree,
+ sp_rank=sp_rank,
+ ring_degree=ring_degree,
+ ring_type=ring_type,
+ )
+
+ return dict(
+ train_dataset=train_dataset,
+ data_collator=data_collator,
+ )
diff --git a/llava/data/datasets_mixture.py b/llava/data/datasets_mixture.py
new file mode 100644
index 0000000000000000000000000000000000000000..62a53fcf5fe69256e5cd4b005ab7f28f63d9fe7c
--- /dev/null
+++ b/llava/data/datasets_mixture.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import warnings
+from dataclasses import dataclass, field
+
+
+@dataclass
+class Dataset:
+ dataset_name: str
+ dataset_type: str = field(default="torch")
+ data_path: str = field(default=None, metadata={"help": "Path to the training data."})
+ meta_path: str = field(default=None, metadata={"help": "Path to the meta data for webdataset."})
+ image_path: str = field(default=None, metadata={"help": "Path to the training image data."})
+ speech_path: str = field(default=None, metadata={"help": "Path to the training speech data."})
+ caption_choice: str = field(default=None, metadata={"help": "Path to the caption directory for recaption."})
+ description: str = field(
+ default=None,
+ metadata={
+ "help": "Detailed desciption of where the data is from, how it is labelled, intended use case and the size of the dataset."
+ },
+ )
+ test_script: str = None
+ maintainer: str = None
+ ############## ############## ############## ############## ############## ##############
+ caption_choice: str = field(default=None, metadata={"help": "Path to the captions for webdataset."})
+ caption_choice_2: str = field(default=None, metadata={"help": "Path to the captions for webdataset."})
+ start_idx: float = field(default=-1, metadata={"help": "Start index of the dataset."})
+ end_idx: float = field(default=-1, metadata={"help": "End index of the dataset."})
+
+
+DATASETS_LEGACY = {}
+
+
+def add_dataset(dataset):
+ if dataset.dataset_name in DATASETS_LEGACY:
+ # make sure the data_name is unique
+ warnings.warn(f"{dataset.dataset_name} already existed in DATASETS. Make sure the name is unique.")
+ assert "+" not in dataset.dataset_name, "Dataset name cannot include symbol '+'."
+ DATASETS_LEGACY.update({dataset.dataset_name: dataset})
+
+
+def register_datasets_mixtures():
+ ############## ############## ############## ############## ############## ##############
+ # Audio Datasets
+ ############## ############## ############## ############## ############## ##############
+
+ data_mixture_1 = Dataset(
+ dataset_name="data_mixture_1",
+ dataset_type="torch",
+ data_path="/path/to/your/data_mixture_1/train.json",
+ )
+ add_dataset(data_mixture_1)
+
+ data_mixture_2 = Dataset(
+ dataset_name="data_mixture_2",
+ dataset_type="torch",
+ data_path="/path/to/your/data_mixture_2/train.json",
+ )
+ add_dataset(data_mixture_2)
+ # Add more data mixtures below
\ No newline at end of file
diff --git a/llava/data/registry/datasets/audio_test.yaml b/llava/data/registry/datasets/audio_test.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..144c86b4bee8d3bbe67b32b7e7948e6f3c8e4bb9
--- /dev/null
+++ b/llava/data/registry/datasets/audio_test.yaml
@@ -0,0 +1,97 @@
+---
+Clotho-AQA-AQA:
+ _target_: llava.data.LLaVADataset
+ data_path: Clotho-AQA-AQA/test.json
+Music-AVQA-AQA_All:
+ _target_: llava.data.LLaVADataset
+ data_path: Music-AVQA-AQA_All/test.json
+CochlScene-SceneClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: CochlScene-SceneClassification/test.json
+NSynth-Source:
+ _target_: llava.data.LLaVADataset
+ data_path: NSynth-Source/test.json
+NSynth-Instrument:
+ _target_: llava.data.LLaVADataset
+ data_path: NSynth-Instrument/test.json
+FSD50k-EventClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: FSD50k-EventClassification/test.json
+Clotho-v2-AudioCaptioning:
+ _target_: llava.data.LLaVADataset
+ data_path: Clotho-v2-AudioCaptioning/test.json
+audiocaps-AudioCaptioning:
+ _target_: llava.data.LLaVADataset
+ data_path: audiocaps-AudioCaptioning/test.json
+ravdess-EmotionClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: ravdess-EmotionClassification/val.json
+GTZAN-GenreClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: GTZAN-GenreClassification/test.json
+UrbanSound8K-EventClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: UrbanSound8K-EventClassification/train.json
+Medley-solos-DB-InstrClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: Medley-solos-DB-InstrClassification/test.json
+ESC50-EventClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: ESC50-EventClassification/train.json
+CREMA-D-EmotionClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: CREMA-D-EmotionClassification/test.json
+IEMOCAP-EmotionClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: IEMOCAP-EmotionClassification/test.json
+MELD-EmotionClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: MELD-EmotionClassification/test.json
+MELD-SentimentClassification:
+ _target_: llava.data.LLaVADataset
+ data_path: MELD-SentimentClassification/test.json
+MMAU:
+ _target_: llava.data.LLaVADataset
+ data_path: MMAU/test.json
+MMAU-mini:
+ _target_: llava.data.LLaVADataset
+ data_path: MMAU/test-mini.json
+AudioEntailmentQA:
+ _target_: llava.data.LLaVADataset
+ data_path: AudioEntailmentQA/test.json
+SPGI-ASR:
+ _target_: llava.data.LLaVADataset
+ data_path: SPGI-ASR/val.json
+SWBD-ASR:
+ _target_: llava.data.LLaVADataset
+ data_path: SWBD-ASR/val.json
+LibriSpeech-ASR-clean:
+ _target_: llava.data.LLaVADataset
+ data_path: LibriSpeech-ASR/test_clean.json
+LibriSpeech-ASR-other:
+ _target_: llava.data.LLaVADataset
+ data_path: LibriSpeech-ASR/test_other.json
+VoxPopuli-ASR:
+ _target_: llava.data.LLaVADataset
+ data_path: VoxPopuli-ASR/test.json
+Europarl-ASR:
+ _target_: llava.data.LLaVADataset
+ data_path: Europarl-ASR/test.json
+CV-ASR:
+ _target_: llava.data.LLaVADataset
+ data_path: CV-ASR/test.json
+GigaSpeech-ASR:
+ _target_: llava.data.LLaVADataset
+ data_path: GigaSpeech-ASR/test.json
+CompA-R-AQA:
+ _target_: llava.data.LLaVADataset
+ data_path: CompA-R-AQA/test.json
+MuschoMusicQA:
+ _target_: llava.data.LLaVADataset
+ data_path: MuschoMusicQA/test.json
+CMM:
+ _target_: llava.data.LLaVADataset
+ data_path: CMM/test.json
+AIR-Bench:
+ _target_: llava.data.LLaVADataset
+ data_path: AIR-Bench/test.json
diff --git a/llava/data/registry/datasets/default.yaml b/llava/data/registry/datasets/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1dc18dfebe7b092e029ba6661c41925d503477ae
--- /dev/null
+++ b/llava/data/registry/datasets/default.yaml
@@ -0,0 +1,5 @@
+---
+dummy:
+ _target_: llava.data.DummyDataset
+ num_instances: 10000
+ comments: dummy dataset for testing
diff --git a/llava/data/registry/mixtures.yaml b/llava/data/registry/mixtures.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..70f69accb593595bd978de144a8b571b6c24b3b6
--- /dev/null
+++ b/llava/data/registry/mixtures.yaml
@@ -0,0 +1,78 @@
+---
+audio_speech_all:
+ -CV-ASR_1
+ -MELD-EmotionClassification+
+ -BBCSoundEffects-AudioDescription
+ -SWBD-ASR_1
+ -WavCaps-SoundBible-AudioCaptioning
+ -AudioSet-Speech-Audio-QA
+ -SONYC-UST-EventClassification
+ -VoxPopuli-ASR_1
+ -FSD50k-EventClassification
+ -SalmonnQA
+ -emov-db-EmotionClassification
+ -LLARK_MagnaTagATune-mir+tess-EmotionClassification
+ -Europarl-ASR_1
+ -jl-corpus-EmotionClassification
+ -Ego-10-AudioCaptioning
+ -SPGI-ASR_1
+ -CREMA-D-EmotionClassification
+ -MusicBenchQA
+ -WavCaps-BBC_Sound_Effects-AudioCaptioning
+ -NSynth-Instrument
+ -SpokenSquadQA
+ -NSynth-MIR
+ -AudioEntailmentQA
+ -GigaSpeech-ASR_1
+ -WavCaps-AudioSet_SL-AudioCaptioning
+ -NonSpeech7k-EventClassification
+ -chime-home-EventClassification
+ -MusicCaps-AudioCaptioning
+ -LP-MusicCaps-MSD-AudioCaptioning
+ -Ego-30-AudioCaptioning
+ -NSynth-Source+Clotho-v2-AudioCaptioning
+ -LP-MusicCaps-MC-AudioCaptioning
+ -Clotho-AQA-EventClassification
+ -WavCaps-FreeSound-AudioCaptioning
+ -LLARK_MagnaTagATune-reasoning
+ -AudioSet-Temporal-Speech-Audio-QA
+ -TUT-EventClassification
+ -ESC50-EventClassification
+ -WavText5K-Tagging
+ -MELD-SentimentClassification
+ -Music-AVQA-AQA_All
+ -Music-AVQA-AVQA_All
+ -MACS-AudioCaptioning
+ -Medley-solos-DB-InstrClassification
+ -AudioSet-EventClassification
+ -OMGEmotion-EmotionClassification
+ -FMA-GenreClassification
+ -Epidemic_sound-AudioCaptioning
+ -CochlScene-SceneClassification
+ -LLARK_FMA-reasoning
+ -ravdess-EmotionClassification
+ -CompA-R-AQA
+ -MU-LLAMA-AQA
+ -musdbhq-InstrClassification
+ -UrbanSound8K-EventClassification
+ -audiocaps-AudioCaptioning
+ -VocalSound-VocalClassification
+ -CLAP_freesound-AudioCaptioning
+ -MMAUQA
+ -SongDescriber-AudioCaptioning
+ -HeySQuADQA
+ -Mira-AudioCaptioning
+ -Clotho-AQA-AQA
+ -LibriSpeech-ASR_1
+ -IEMOCAP-EmotionClassification
+ -AudioSetFullwoAudioMusicCaps-EventClassification
+ -MSP-PODCAST-Publish-1.9-EmotionClassification
+ -OpenAQA-AQA
+ -SoundDescs-AudioDescription
+ -LibriSQA
+ -LLARK_FMA-mir
+ -LP-MusicCaps-MTT-AudioCaptioning
+ -GTZAN-GenreClassification
+ -musdbhq-captioning
+ -YesNoQA
+
diff --git a/llava/entry.py b/llava/entry.py
new file mode 100644
index 0000000000000000000000000000000000000000..44348e95eea2598a030f6953ff2098e133bed68f
--- /dev/null
+++ b/llava/entry.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import typing
+from typing import List, Optional
+
+if typing.TYPE_CHECKING:
+ from transformers import PreTrainedModel
+else:
+ PreTrainedModel = None
+
+__all__ = ["load"]
+
+
+def load(
+ model_path: str,
+ model_base: Optional[str] = None,
+ devices: Optional[List[int]] = None,
+ **kwargs,
+) -> PreTrainedModel:
+ import torch
+
+ from llava.conversation import auto_set_conversation_mode
+ from llava.mm_utils import get_model_name_from_path
+ from llava.model.builder import load_pretrained_model
+
+ auto_set_conversation_mode(model_path)
+
+ model_name = get_model_name_from_path(model_path)
+ model_path = os.path.expanduser(model_path)
+ if os.path.exists(os.path.join(model_path, "model")):
+ model_path = os.path.join(model_path, "model")
+
+ # Set `max_memory` to constrain which GPUs to use
+ if devices is not None:
+ assert "max_memory" not in kwargs, "`max_memory` should not be set when `devices` is set"
+ kwargs.update(max_memory={device: torch.cuda.get_device_properties(device).total_memory for device in devices})
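+ # e.g. load(model_path, devices=[0, 1]) caps max_memory to GPUs 0 and 1 so the model is dispatched only there (illustrative)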
+
+ model = load_pretrained_model(model_path, model_name, model_base, **kwargs)[1]
+ return model
diff --git a/llava/eval/__init__.py b/llava/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d807fb5a2b32ea471a6bfec0d57694f57f91e377
--- /dev/null
+++ b/llava/eval/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import os
+
+from llava.utils import io
+
+__all__ = ["EVAL_ROOT", "TASKS"]
+
+
+EVAL_ROOT = "scripts/eval"
+TASKS = io.load(os.path.join(os.path.dirname(__file__), "registry_audio.yaml"))
diff --git a/llava/eval/eval_audio_bench.py b/llava/eval/eval_audio_bench.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3017aef3ba10519f507e4d878d834c7ab3d7dbf
--- /dev/null
+++ b/llava/eval/eval_audio_bench.py
@@ -0,0 +1,117 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import argparse
+import csv
+import itertools
+import json
+import os
+
+import torch
+from datasets import load_dataset
+from tqdm import tqdm
+
+import llava
+from llava import conversation as conversation_lib
+from llava.data.builder import DATASETS
+from llava.eval.mmmu_utils.eval_utils import parse_choice
+from llava.utils import distributed as dist
+from llava.utils import io
+from llava.utils.logging import logger
+
+
+def load_existing_ids(output_file):
+ if not os.path.exists(output_file):
+ return set(), []
+ try:
+ with open(output_file, "r") as f:
+ lines = f.readlines()
+ outputs = [json.loads(line) for line in lines]
+ processed_ids = {item["id"] for item in outputs}
+ return processed_ids, outputs
+ except Exception as e:
+ print(f"Error loading existing outputs: {e}")
+ return set(), []
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default=None)
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--task", type=str, default=None)
+ parser.add_argument("--conv-mode", type=str, default="auto")
+ parser.add_argument("--generation-config", type=json.loads)
+ parser.add_argument("--output-dir", type=str, default=None)
+ args = parser.parse_args()
+
+ # Set up distributed environment
+ dist.init()
+ devices = range(dist.local_rank(), torch.cuda.device_count(), dist.local_size())
+ torch.cuda.set_device(devices[0])
+
+ # Load the stage 3 model
+ model = llava.load(args.model_base, model_base=None, devices=devices)
+ # Uncomment the block below to load the stage 3.5 PEFT adapter on top of the stage 3 model for thinking mode and long audio mode
+ # model = PeftModel.from_pretrained(
+ # model,
+ # args.model_path,
+ # device_map="auto",
+ # torch_dtype=torch.float16,
+ # )
+ # Set up generation config
+ generation_config = model.default_generation_config
+ if args.generation_config is not None:
+ generation_config.update(**args.generation_config)
+
+ # Load data and chunk it
+ json_file = DATASETS[args.task]["data_path"]
+ instances = io.load(json_file)
+ instances = instances[dist.rank() :: dist.size()]
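+ # round-robin sharding: rank r keeps instances r, r + dist.size(), r + 2 * dist.size(), ...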
+
+ output_path = os.path.join(args.output_dir, f"outputs_{args.task}.jsonl")
+ processed_ids, outputs = load_existing_ids(output_path)
+
+ count = len(outputs)
+ # Run inference
+ new_outputs = []
+ for instance in tqdm(instances, disable=not dist.is_main()):
+ uuid = instance["id"]
+ sound_path = instance["sound"]
+
+ if sound_path in processed_ids:
+ continue # Skip if already processed
+ sound = llava.Sound(sound_path)
+ conversations = instance["conversations"]
+ question = conversations[0]["value"]
+
+ response = model.generate_content([sound, question], generation_config=generation_config)
+
+ print("response", response)
+
+ output = {"id": sound_path, "question": question, "gt_answer": conversations[1]["value"], "pred": response}
+ new_outputs.append(output)
+ count += 1
+ if count % 20 == 0:
+ # Gather and save outputs
+ if dist.size() > 1:
+ outputs_new = dist.gather(new_outputs, dst=0)
+ if dist.is_main():
+ outputs_new = list(itertools.chain(*outputs_new))
+ final_outputs = outputs + outputs_new
+ io.save(os.path.join(args.output_dir, f"outputs_{args.task}.jsonl"), final_outputs)
+ else:
+ final_outputs = outputs + new_outputs
+ io.save(os.path.join(args.output_dir, f"outputs_{args.task}.jsonl"), final_outputs)
+ if dist.size() > 1:
+ new_outputs = dist.gather(new_outputs, dst=0)
+ if not dist.is_main():
+ return
+ new_outputs = list(itertools.chain(*new_outputs))
+ final_outputs = outputs + new_outputs
+ io.save(os.path.join(args.output_dir, f"outputs_{args.task}.jsonl"), final_outputs)
+
+if __name__ == "__main__":
+ main()
diff --git a/llava/eval/mmmu_utils/__pycache__/eval_utils.cpython-311.pyc b/llava/eval/mmmu_utils/__pycache__/eval_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..62c020ee5b25a39cc6679b0f7f28235d1db3120c
Binary files /dev/null and b/llava/eval/mmmu_utils/__pycache__/eval_utils.cpython-311.pyc differ
diff --git a/llava/eval/mmmu_utils/eval_utils.py b/llava/eval/mmmu_utils/eval_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..42c3cc6bc917ffa3f5af7ed7c2ec601ef8dbfbc1
--- /dev/null
+++ b/llava/eval/mmmu_utils/eval_utils.py
@@ -0,0 +1,61 @@
+# This file is originated from the official MMMU codebase:
+# https://github.com/MMMU-Benchmark/MMMU
+
+import random
+
+import numpy as np
+
+
+def parse_choice(response, all_choices, index2ans=None):
+ """
+ Parse the prediction from the generated response.
+ Return the predicted index e.g., A, B, C, D.
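+ For example, parse_choice("The answer is (B).", ["A", "B", "C", "D"]) returns "B".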
+ """
+ for char in [",", ".", "!", "?", ";", ":", "'"]:
+ response = response.strip(char)
+ response = " " + response + " " # add space to avoid partial match
+
+ index_ans = True
+ ans_with_brack = False
+ candidates = []
+ for choice in all_choices: # e.g., (A) (B) (C) (D)
+ if f"({choice})" in response:
+ candidates.append(choice)
+ ans_with_brack = True
+
+ if len(candidates) == 0:
+ for choice in all_choices: # e.g., A B C D
+ if f" {choice} " in response:
+ candidates.append(choice)
+
+ # if the above fails to find candidates and the response is longer than 5 tokens, try to match the answer content instead
+ if len(candidates) == 0 and len(response.split()) > 5 and index2ans is not None:
+ for index, ans in index2ans.items():
+ if ans.lower() in response.lower():
+ candidates.append(index)
+ index_ans = False # it's content ans.
+
+ if len(candidates) == 0: # still no answer found; randomly choose one.
+ pred_index = random.choice(all_choices)
+ elif len(candidates) > 1:
+ start_indexes = []
+ if index_ans:
+ if ans_with_brack:
+ for can in candidates:
+ index = response.rfind(f"({can})")
+ start_indexes.append(index) # -1 will be ignored anyway
+ # start_indexes = [generated_response.index(f'({can})') for can in candidates]
+ else:
+ for can in candidates:
+ index = response.rfind(f" {can} ")
+ start_indexes.append(index)
+ else:
+ for can in candidates:
+ index = response.lower().rfind(index2ans[can].lower())
+ start_indexes.append(index)
+ # get the last one
+ pred_index = candidates[np.argmax(start_indexes)]
+ else: # if only one candidate, use it.
+ pred_index = candidates[0]
+
+ return pred_index
diff --git a/llava/eval/registry_audio.yaml b/llava/eval/registry_audio.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8de55c73c3d12785044fb0cbd09c847eb548a514
--- /dev/null
+++ b/llava/eval/registry_audio.yaml
@@ -0,0 +1,93 @@
+Clotho-AQA-AQA:
+ tags:
+ - local
+Music-AVQA-AQA_All:
+ tags:
+ - local
+CochlScene-SceneClassification:
+ tags:
+ - local
+NSynth-Source:
+ tags:
+ - local
+NSynth-Instrument:
+ tags:
+ - local
+FSD50k-EventClassification:
+ tags:
+ - local
+Clotho-v2-AudioCaptioning:
+ tags:
+ - local
+audiocaps-AudioCaptioning:
+ tags:
+ - local
+ravdess-EmotionClassification:
+ tags:
+ - local
+GTZAN-GenreClassification:
+ tags:
+ - local
+UrbanSound8K-EventClassification:
+ tags:
+ - local
+Medley-solos-DB-InstrClassification:
+ tags:
+ - local
+ESC50-EventClassification:
+ tags:
+ - local
+CREMA-D-EmotionClassification:
+ tags:
+ - local
+IEMOCAP-EmotionClassification:
+ tags:
+ - local
+MELD-EmotionClassification:
+ tags:
+ - local
+MELD-SentimentClassification:
+ tags:
+ - local
+MMAU:
+ tags:
+ - local
+AudioEntailmentQA:
+ tags:
+ - local
+SPGI-ASR:
+ tags:
+ - local
+SWBD-ASR:
+ tags:
+ - local
+LibriSpeech-ASR-clean:
+ tags:
+ - local
+LibriSpeech-ASR-other:
+ tags:
+ - local
+VoxPopuli-ASR:
+ tags:
+ - local
+Europarl-ASR:
+ tags:
+ - local
+CV-ASR:
+ tags:
+ - local
+GigaSpeech-ASR:
+ tags:
+ - local
+CompA-R-AQA:
+ tags:
+ - local
+MuschoMusicQA:
+ tags:
+ - local
+CMM:
+ tags:
+ - local
+AIR-Bench:
+ tags:
+ - local
diff --git a/llava/media.py b/llava/media.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bcc0aa44034212a4649b9ef256eb5d995e48380
--- /dev/null
+++ b/llava/media.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+__all__ = ["Media", "File", "Image", "Video", "Speech", "Sound"]
+
+
+class Media:
+ pass
+
+
+class File(Media):
+ def __init__(self, path: str) -> None:
+ self.path = path
+
+
+class Image(File):
+ pass
+
+
+class Video(File):
+ pass
+
+
+class Speech(File):
+ pass
+
+class Sound(File):
+ pass
\ No newline at end of file
diff --git a/llava/mm_utils.py b/llava/mm_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f4a23d338f9a4440aeaa6f44f7f5fb0d89093fe
--- /dev/null
+++ b/llava/mm_utils.py
@@ -0,0 +1,641 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# dynamic_preprocess and find_closest_aspect_ratio are referenced from https://github.com/OpenGVLab/InternVL
+
+import base64
+import os
+import tempfile
+from io import BytesIO
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import StoppingCriteria
+
+from pydub import AudioSegment
+from torchvision import transforms
+import soundfile as sf
+from librosa import resample as librosa_resample
+import whisper
+import random
+from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, UniformClipSampler
+
+DEFAULT_AUDIO_FRAME_SHIFT_MS = 10  # in milliseconds
+
+
+def get_frame_from_vcap(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
+ import cv2
+
+ if fps is None or frame_count is None:
+ # if one of fps or frame_count is None, still recompute
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+ if fps == 0 or frame_count == 0:
+ print(f"Video file not found. return empty images. {video_file_name}")
+ return [
+ Image.new("RGB", (720, 720)),
+ ] * num_frames, 0, [0.]
+
+ duration = frame_count / fps
+ frame_interval = frame_count // num_frames
+ if frame_interval == 0 and frame_count <= 1:
+ print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
+ return [
+ Image.new("RGB", (720, 720)),
+ ] * num_frames, 0, [0.]
+ # print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval)
+
+ images = []
+ count = 0
+ success = True
+ frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
+ frame_times = [frame / fps for frame in frame_indices]
+ while success:
+ # print("frame_count:", frame_count, "count:", count, "num_frames:", num_frames, "frame_interval:", frame_interval)
+ if frame_count >= num_frames:
+ success, frame = vidcap.read()
+ if count in frame_indices:
+ try:
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ im_pil = Image.fromarray(img)
+ images.append(im_pil)
+ except BaseException:
+ continue
+ if len(images) >= num_frames:
+ return images, num_frames, frame_times
+ count += 1
+ else:
+ # Left padding frames if the video is not long enough
+ success, frame = vidcap.read()
+ if success:
+ try:
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ im_pil = Image.fromarray(img)
+ images.append(im_pil)
+ except BaseException:
+ continue
+ count += 1
+ else:
+ break
+ if len(images) == 0:
+ raise ValueError("Did not find enough frames in the video. return empty image.")
+
+ return images, len(images), frame_times
+
+
+def get_frame_from_vcap_with_fps(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
+ """
+ num_frames is the max number of frames the model can support.
+ frame_count is the number of frames in the input video.
+ max_fps is the max FPS the model can support.
+ fps is the fps of the input video.
+ """
+
+ import random
+
+ import cv2
+
+ if fps is None or frame_count is None:
+ # recompute both from the capture if either is missing
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+ if fps == 0 or frame_count == 0:
+ print(f"Video file not found. return empty images. {video_file_name}")
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
+ return [
+ Image.new("RGB", (720, 720)),
+ ] * empty_video_frames, 0, [0.]
+
+ duration = frame_count / fps
+ # print("duration:", duration, "frames:", frame_count, "fps:", fps, "num_frames:", num_frames, "max_fps:", max_fps)
+ # If the video is too long (longer than max_fps and num_frames can support),
+ # we will use lower fps to sample frames.
+ if duration >= num_frames / max_fps:
+ frame_interval = frame_count // num_frames
+
+ # If the video is too short, we will skip the video if there is only one frame.
+ if frame_interval == 0 and frame_count <= 1:
+ print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
+ return [
+ Image.new("RGB", (720, 720)),
+ ] * empty_video_frames, 0, [0.]
+
+ images = []
+ count = 0
+ success = True
+ frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
+ frame_times = [frame / fps for frame in frame_indices]
+ while success:
+ if frame_count >= num_frames:
+ # success, frame = vidcap.read()
+ if count in frame_indices:
+ success, frame = vidcap.read()
+ try:
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ im_pil = Image.fromarray(img)
+ images.append(im_pil)
+ except:
+ # print("Failed to read frame:", count)
+ continue
+ if len(images) >= num_frames:
+ return images, num_frames, frame_times
+ else:
+ success = vidcap.grab()
+ count += 1
+ else:
+ # Left padding frames if the video is not long enough
+ success, frame = vidcap.read()
+ if success:
+ try:
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ im_pil = Image.fromarray(img)
+ images.append(im_pil)
+ except:
+ # print("Failed to read frame:", count)
+ continue
+ count += 1
+ else:
+ break
+ else:
+ frames_required = int(duration * max_fps)
+ frame_indices = np.linspace(0, frame_count - 1, frames_required, dtype=int)
+ if frames_required == 0:
+ print(f"frames_required is fewer than 2. Duration {duration}, return empty image.")
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
+ return [
+ Image.new("RGB", (720, 720)),
+ ] * empty_video_frames, 0, [0.]
+ elif frames_required == 1:
+ frame_indices = np.linspace(0, frame_count - 1, 2, dtype=int)
+ images = []
+ count = 0
+ looked = 0
+ success = True
+
+ while success:
+ success, frame = vidcap.read()
+ if success and (looked in frame_indices):
+ try:
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ im_pil = Image.fromarray(img)
+ images.append(im_pil)
+ except:
+ continue
+ count += 1
+ looked += 1
+ frame_times = [frame / fps for frame in frame_indices]
+
+ if len(images) == 0:
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
+ return [
+ Image.new("RGB", (720, 720)),
+ ] * empty_video_frames, 0, [0.]
+ else:
+ return images, len(images), frame_times
+
+
+def opencv_extract_frames(vpath_or_bytesio, frames=6, max_fps=0.0, fps=None, frame_count=None):
+ """
+ Extract frames from a video using OpenCV.
+
+ Args:
+ vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
+ frames (int): Maximum number of frames to extract from the video.
+ max_fps (float): Maximum sampling FPS. If 0.0, frames are sampled at equal intervals across the whole video.
+ fps (float, optional): Frames per second of the input video. Recomputed from the capture if None.
+ frame_count (int, optional): Total number of frames in the input video. Recomputed from the capture if None.
+
+ Returns:
+ tuple: (list of PIL Images, number of frames actually extracted, list of frame timestamps in seconds).
+
+ Raises:
+ NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
+ """
+ import cv2
+ if isinstance(vpath_or_bytesio, str):
+ vidcap = cv2.VideoCapture(vpath_or_bytesio)
+ if max_fps > 0.0:
+ return get_frame_from_vcap_with_fps(
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
+ )
+ return get_frame_from_vcap(
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
+ )
+ elif isinstance(vpath_or_bytesio, (BytesIO,)):
+ # assuming mp4
+ with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
+ temp_video.write(vpath_or_bytesio.read())
+ temp_video_name = temp_video.name
+ vidcap = cv2.VideoCapture(temp_video_name)
+ if max_fps > 0.0:
+ return get_frame_from_vcap_with_fps(
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
+ )
+ return get_frame_from_vcap(
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
+ )
+ else:
+ raise NotImplementedError(type(vpath_or_bytesio))
+
+
+def load_image_from_base64(image):
+ return Image.open(BytesIO(base64.b64decode(image)))
+
+
+def expand2square(pil_img, background_color):
+ """
+ Expand the given PIL image to a square shape by adding padding.
+
+ Parameters:
+ - pil_img: The PIL image to be expanded.
+ - background_color: The color of the padding to be added.
+
+ Returns:
+ - The expanded PIL image.
+
+ If the image is already square, it is returned as is.
+ If the image is wider than it is tall, padding is added to the top and bottom.
+ If the image is taller than it is wide, padding is added to the left and right.
+ """
+ width, height = pil_img.size
+ if pil_img.mode == "L":
+ background_color = background_color[0]
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+ best_ratio_diff = float("inf")
+ best_ratio = (1, 1)
+ area = width * height
+ for ratio in target_ratios:
+ target_aspect_ratio = ratio[0] / ratio[1]
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+ if ratio_diff < best_ratio_diff:
+ best_ratio_diff = ratio_diff
+ best_ratio = ratio
+ elif ratio_diff == best_ratio_diff:
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+ best_ratio = ratio
+ return best_ratio
+
+
+def dynamic_preprocess(image, min_num=1, max_num=12, image_size=384, use_thumbnail=True):
+ orig_width, orig_height = image.size
+ aspect_ratio = orig_width / orig_height
+
+ # calculate the existing image aspect ratio
+ target_ratios = {
+ (i, j)
+ for n in range(min_num, max_num + 1)
+ for i in range(1, n + 1)
+ for j in range(1, n + 1)
+ if i * j <= max_num and i * j >= min_num
+ }
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+ # find the closest aspect ratio to the target
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+ # calculate the target width and height
+ target_width = image_size * target_aspect_ratio[0]
+ target_height = image_size * target_aspect_ratio[1]
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+ # resize the image
+ resized_img = image.resize((target_width, target_height))
+ processed_images = []
+ for i in range(blocks):
+ box = (
+ (i % (target_width // image_size)) * image_size,
+ (i // (target_width // image_size)) * image_size,
+ ((i % (target_width // image_size)) + 1) * image_size,
+ ((i // (target_width // image_size)) + 1) * image_size,
+ )
+ # split the image
+ split_img = resized_img.crop(box)
+ processed_images.append(split_img)
+ assert len(processed_images) == blocks
+ if use_thumbnail and len(processed_images) != 1:
+ thumbnail_img = image.resize((image_size, image_size))
+ processed_images.append(thumbnail_img)
+ return processed_images
+
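+# Worked example for dynamic_preprocess (editor's note, not in the original code):
+# an 800x600 image with image_size=384 and max_num=12 has aspect ratio 4:3, so the
+# closest ratio is (4, 3); the image is resized to 1536x1152, split into 12 tiles of
+# 384x384, and a 384x384 thumbnail is appended, giving 13 crops in total.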
+
+def dynamic_s2_preprocess(image, s2_scales=[384, 768, 1152], max_num=12, image_size=384):
+ orig_width, orig_height = image.size
+ aspect_ratio = orig_width / orig_height
+ min_num = (s2_scales[-1] // s2_scales[0]) ** 2 # at least use number of tiles as the largest scale
+
+ processed_images = []
+
+ ##########################################################################################
+ ############ Add tiles for all but the last scale using a fixed square ratio #############
+ ##########################################################################################
+
+ for scale in s2_scales[:-1]:
+ target_width = image_size * (scale // s2_scales[0])
+ target_height = image_size * (scale // s2_scales[0])
+ blocks = (scale // s2_scales[0]) ** 2
+
+ # resize the image
+ resized_img = image.resize((target_width, target_height))
+ for i in range(blocks):
+ box = (
+ (i % (target_width // image_size)) * image_size,
+ (i // (target_width // image_size)) * image_size,
+ ((i % (target_width // image_size)) + 1) * image_size,
+ ((i // (target_width // image_size)) + 1) * image_size,
+ )
+ # split the image
+ split_img = resized_img.crop(box)
+ processed_images.append(split_img)
+
+ ##########################################################################################
+ ################ Add tiles for the last scale using dynamic aspect ratio #################
+ ##########################################################################################
+
+ # calculate the existing image aspect ratio
+ target_ratios = {
+ (i, j)
+ for n in range(min_num, max_num + 1)
+ for i in range(1, n + 1)
+ for j in range(1, n + 1)
+ if i * j <= max_num and i * j >= min_num
+ }
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+ # find the closest aspect ratio to the target
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+ # calculate the target width and height
+ target_width = image_size * target_aspect_ratio[0]
+ target_height = image_size * target_aspect_ratio[1]
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+ # resize the image
+ resized_img = image.resize((target_width, target_height))
+ for i in range(blocks):
+ box = (
+ (i % (target_width // image_size)) * image_size,
+ (i // (target_width // image_size)) * image_size,
+ ((i % (target_width // image_size)) + 1) * image_size,
+ ((i // (target_width // image_size)) + 1) * image_size,
+ )
+ # split the image
+ split_img = resized_img.crop(box)
+ processed_images.append(split_img)
+
+ return processed_images, (target_aspect_ratio[1], target_aspect_ratio[0])
+
+
+
+def dynamic_s2_process_images_and_prompt(images, data_args, image_folder=None):
+ idx = 0
+ all_images = []
+ all_block_size = []
+ for img in images:
+ processed_images, block_size = process_image(img, data_args, image_folder, enable_dynamic_s2=True)
+ all_images.append(processed_images)
+ all_block_size.append(block_size)
+ idx += 2
+ if all_images:
+ all_images = torch.cat(all_images)
+ else:
+ all_images = None
+ return all_images, all_block_size
+
+
+def process_image(
+ image_file, data_args, image_folder, enable_dynamic_res=False, enable_dynamic_s2=False, max_tiles=None
+):
+ processor = data_args.image_processor
+ if isinstance(image_file, str):
+ if image_folder is not None:
+ image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
+ else:
+ image = Image.open(image_file).convert("RGB")
+ else:
+ # image is stored in bytearray
+ image = image_file
+ image = image.convert("RGB")
+ if hasattr(data_args.image_processor, "crop_size"):
+ # CLIP vision tower
+ crop_size = data_args.image_processor.crop_size
+ else:
+ # SIGLIP vision tower
+ assert hasattr(data_args.image_processor, "size")
+ crop_size = data_args.image_processor.size
+ if "dynamic_s2" in data_args.image_aspect_ratio and enable_dynamic_s2:
+ assert crop_size["height"] == crop_size["width"]
+ images, block_size = dynamic_s2_preprocess(
+ image, s2_scales=data_args.s2_scales, max_num=data_args.max_tiles, image_size=crop_size["height"]
+ )
+ images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
+ return torch.stack(images), block_size
+ if "dynamic" in data_args.image_aspect_ratio and enable_dynamic_res:
+ assert crop_size["height"] == crop_size["width"]
+ if max_tiles is not None:
+ max_num = max_tiles
+ else:
+ max_num = data_args.max_tiles
+ images = dynamic_preprocess(image, min_num=data_args.min_tiles, max_num=max_num, image_size=crop_size["height"])
+ images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
+ return torch.stack(images)
+
+ if data_args.image_aspect_ratio == "resize":
+ image = image.resize((crop_size["width"], crop_size["height"]))
+ if data_args.image_aspect_ratio == "pad":
+
+ def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+ image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
+ image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
+ else:
+ # Using default behavior of the vision encoder
+ # For CLIP, default is central crop
+ # For Radio, default is central crop
+ # For Siglip, default is resize
+ # For InternVIT, default is resize
+ image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
+ return image
+
+def get_num_windows(T, sr, max_num_window=5):
+
+ window_length = int(30.0 * sr)
+ window_overlap = int(0.0 * sr)
+ max_num_window = max_num_window
+
+ num_windows = 1
+ if T <= window_length:
+ num_windows = 1
+ full_length = window_length
+ elif T >= (max_num_window * window_length - (max_num_window - 1) * window_overlap):
+ num_windows = max_num_window
+ full_length = (max_num_window * window_length - (max_num_window - 1) * window_overlap)
+ else:
+ num_windows = 1 + int(np.ceil((T - window_length) / float(window_length - window_overlap)))
+ full_length = num_windows * window_length - (num_windows - 1) * window_overlap
+
+ return num_windows, full_length
+
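+# Worked example for get_num_windows (editor's note): 100 s of 16 kHz audio gives
+# T = 1,600,000 samples and window_length = 480,000 (30 s); since T is below the
+# 5-window maximum of 2,400,000, num_windows = 1 + ceil((1,600,000 - 480,000) / 480,000) = 4
+# and full_length = 4 * 480,000 = 1,920,000 samples (i.e. the audio is padded to 120 s).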
+def load_audio(file_path, target_sr=16000, duration=30.0, start=0.0):
+ if file_path.endswith('.mp3'):
+ audio = AudioSegment.from_file(file_path)
+ if len(audio) > (start + duration) * 1000:
+ audio = audio[start * 1000:(start + duration) * 1000]
+
+ if audio.frame_rate != target_sr:
+ audio = audio.set_frame_rate(target_sr)
+
+ if audio.channels > 1:
+ audio = audio.set_channels(1)
+
+ data = np.array(audio.get_array_of_samples())
+ if audio.sample_width == 2:
+ data = data.astype(np.float32) / np.iinfo(np.int16).max
+ elif audio.sample_width == 4:
+ data = data.astype(np.float32) / np.iinfo(np.int32).max
+ else:
+ raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))
+
+ else:
+ with sf.SoundFile(file_path) as audio:
+ original_sr = audio.samplerate
+ channels = audio.channels
+
+ max_frames = int((start + duration) * original_sr)
+
+ audio.seek(int(start * original_sr))
+ frames_to_read = min(max_frames, len(audio))
+ data = audio.read(frames_to_read)
+
+ if data.max() > 1 or data.min() < -1:
+ data = data / max(abs(data.max()), abs(data.min()))
+
+ if original_sr != target_sr:
+ if channels == 1:
+ data = librosa_resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
+ else:
+ data = librosa_resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
+ else:
+ if channels != 1:
+ data = data.T[0]
+
+ if data.min() >= 0:
+ data = 2 * data / abs(data.max()) - 1.0
+ else:
+ data = data / max(abs(data.max()), abs(data.min()))
+
+ assert len(data.shape) == 1, data.shape
+ return data
+
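+# Example usage of load_audio (editor's sketch; "speech.wav" is a hypothetical path):
+# samples = load_audio("speech.wav", target_sr=16000, duration=30.0)
+# returns a mono 1-D float array resampled to 16 kHz with values roughly in [-1, 1].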
+def process_images(images, image_processor, model_cfg, enable_dynamic_res=False, max_tiles=None):
+ model_cfg.image_processor = image_processor
+ new_images = [
+ process_image(image, model_cfg, None, enable_dynamic_res=enable_dynamic_res, max_tiles=max_tiles)
+ for image in images
+ ]
+
+ if all(x.shape == new_images[0].shape for x in new_images):
+ if len(new_images[0].shape) == 4:
+ new_images = torch.cat(new_images, dim=0)
+ elif len(new_images[0].shape) == 3:
+ new_images = torch.stack(new_images, dim=0)
+ else:
+ raise ValueError(f"new_images rank does not equal to 4, rank: {len(new_images[0].shape)}")
+ else:
+ raise ValueError("The shape of images in new_images is different!")
+ return new_images
+
+def process_sounds(sounds):
+ sounds = torch.tensor(sounds)
+ return sounds
+
+def process_sound_masks(masks):
+ masks = torch.tensor(masks[0])
+ return masks
+
+def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
+ return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
+
+def is_gemma_tokenizer(tokenizer):
+ return "gemma" in tokenizer.__class__.__name__.lower()
+
+
+def get_model_name_from_path(model_path):
+ model_path = model_path.strip("/")
+ model_paths = model_path.split("/")
+ if model_paths[-1].startswith("checkpoint-"):
+ return model_paths[-2] + "_" + model_paths[-1]
+ else:
+ return model_paths[-1]
+
+class KeywordsStoppingCriteria(StoppingCriteria):
+ def __init__(self, keywords, tokenizer, input_ids):
+ self.keywords = keywords
+ self.keyword_ids = []
+ self.max_keyword_len = 0
+ for keyword in keywords:
+ cur_keyword_ids = tokenizer(keyword).input_ids
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
+ cur_keyword_ids = cur_keyword_ids[1:]
+ if len(cur_keyword_ids) > self.max_keyword_len:
+ self.max_keyword_len = len(cur_keyword_ids)
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
+ self.tokenizer = tokenizer
+ self.start_len = input_ids.shape[1]
+
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
+ for keyword_id in self.keyword_ids:
+ if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
+ return True
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
+ for keyword in self.keywords:
+ if keyword in outputs:
+ return True
+ return False
+
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ outputs = []
+ for i in range(output_ids.shape[0]):
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
+ return all(outputs)
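+# Illustrative usage of KeywordsStoppingCriteria (editor's sketch, not part of the original code):
+# stop = KeywordsStoppingCriteria(["</s>"], tokenizer, input_ids)
+# output_ids = model.generate(input_ids, stopping_criteria=StoppingCriteriaList([stop]))
+# where StoppingCriteriaList would need to be imported from transformers.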
diff --git a/llava/model/FloatPointQuantizeTorch.py b/llava/model/FloatPointQuantizeTorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b01c6f8ad008cdcd01c51a85399a194032df05b1
--- /dev/null
+++ b/llava/model/FloatPointQuantizeTorch.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import math
+
+import torch
+
+
+def floatExMy_quantize_torch(x, e_bit, m_bit, stochastic):
+ sign, x_abs = x.sign(), x.abs()
+ Elow, Ehigh, Mhigh = -(2 ** (e_bit - 1)) + 2, 2 ** (e_bit - 1), 2**m_bit
+ expo = torch.floor(torch.log2(x_abs))
+ expo = torch.clamp(expo, min=Elow, max=Ehigh)
+ mant = x_abs / torch.exp2(expo)
+
+ mant_int = torch.floor(mant)
+ mant_frac = mant - mant_int
+ mant_frac = mant_frac * Mhigh
+ if stochastic:
+ noise = mant_frac.new(mant_frac.shape).uniform_(-0.5, 0.5)
+ mant_frac.add_(noise)
+ mant_frac = torch.round(mant_frac)
+
+ mant_q = mant_int + mant_frac / Mhigh
+ y = sign * (2**expo) * mant_q
+ y = y.to(x)
+
+ return y
+
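+# Worked example for floatExMy_quantize_torch (editor's note): for E4M3 (e_bit=4, m_bit=3)
+# and x = 3.7, expo = floor(log2(3.7)) = 1, mant = 1.85, and round(0.85 * 8) = 7, so the
+# quantized value is 2**1 * (1 + 7/8) = 3.75.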
+
+def floatExM0_quantize_torch(x, e_bit, stochastic):
+ sign, x_abs = x.sign(), x.abs()
+ Elow, Ehigh = -(2 ** (e_bit - 1)) + 1, 2 ** (e_bit - 1)
+ expo = torch.log2(x_abs)
+ if stochastic:
+ noise = expo.new(expo.shape).uniform_(-0.5, 0.5)
+ expo.add_(noise) # in-place add; .add() would discard the result
+ log_bias = math.log2(4 / 3) - 1 / 2
+ expo.add_(torch.ones_like(expo) * log_bias) # in-place add; .add() would discard the result
+ expo = torch.clamp(expo, min=Elow - 1, max=Ehigh)
+ expo = torch.round(expo)
+
+ y = sign * (2**expo) * (expo > Elow) # When underflow, set the value to 0
+ y = y.to(x)
+
+ return y
+
+
+def Dynamic_quantize_torch(x, bit, stochastic):
+ if stochastic:
+ raise NotImplementedError("Dynamic Tree quantization does not support stochastic")
+ sign, x_abs = x.sign(), x.abs()
+ expo = torch.ceil(torch.log10(x_abs))
+ expo = torch.clamp(expo, min=2 - bit)
+ mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1
+
+ mant_frac = mant * 2 ** (bit - 2 - expo.abs())
+ mant_frac = torch.round(mant_frac)
+ mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1
+ y = sign * (10**expo) * mant_frac / 10
+
+ zero_mask = y.abs() > 1.01 * 10 ** (1 - bit)
+ y = y * zero_mask
+ y = y.to(x)
+ return y
+
+
+def ZeroDynamic_quantize_torch(x, bit, stochastic):
+ if stochastic:
+ raise NotImplementedError("Dynamic Tree quantization does not support stochastic")
+ sign, x_abs = x.sign(), x.abs()
+ expo = torch.ceil(torch.log10(x_abs))
+ expo = torch.clamp(expo, min=2 - bit)
+ mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1
+
+ mant_frac = mant * 2 ** (bit - 2 - expo.abs())
+ mant_frac = torch.round(mant_frac)
+ mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1
+ y = sign * (10**expo) * mant_frac / 10
+
+ y = y.to(x)
+ return y
diff --git a/llava/model/FloatPointQuantizeTriton.py b/llava/model/FloatPointQuantizeTriton.py
new file mode 100644
index 0000000000000000000000000000000000000000..e619c3739ec968c92460b3d8d9efd2736fd07d44
--- /dev/null
+++ b/llava/model/FloatPointQuantizeTriton.py
@@ -0,0 +1,199 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import math
+import struct
+
+import numpy as np
+import torch
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+segment_size = 1024**3
+
+
+def floatExMy_quantize_triton(x, e_bit, m_bit, stochastic):
+ x_ori_shape = x.shape
+ x = x.view(-1)
+
+ n_elements = x.numel()
+
+ if n_elements <= segment_size:
+ grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+ y = torch.empty_like(x)
+
+ if x.dtype in [torch.bfloat16, torch.float32]:
+ if stochastic:
+ noise = x.new(x.shape).uniform_(-0.5, 0.5)
+ _floatExMy_stochastic_quantize_kernel[grid](x, noise, y, n_elements, e_bit, m_bit)
+ else:
+ _floatExMy_quantize_kernel[grid](x, y, n_elements, e_bit, m_bit)
+ torch.cuda.synchronize()
+ else:
+ raise NotImplementedError(f"Other data format {x.dtype} for float quantization triton")
+ else: # Triton will break when x.numel > 2 * 1024 ** 3
+ num_segments = n_elements // segment_size + 1
+ split_size = [segment_size] * (num_segments - 1) + [n_elements - segment_size * (num_segments - 1)]
+ x_list = x.split(split_size)
+ y_list = []
+ del x
+
+ for x in x_list:
+ n_elements = x.numel()
+ grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+ y = torch.empty_like(x)
+
+ if x.dtype in [torch.bfloat16, torch.float32]:
+ if stochastic:
+ noise = x.new(x.shape).uniform_(-0.5, 0.5)
+ _floatExMy_stochastic_quantize_kernel[grid](x, noise, y, n_elements, e_bit, m_bit)
+ else:
+ _floatExMy_quantize_kernel[grid](x, y, n_elements, e_bit, m_bit)
+ torch.cuda.synchronize()
+ else:
+ raise NotImplementedError(f"Other data format {x.dtype} for float quantization triton")
+
+ y_list.append(y)
+ y = torch.concat(y_list)
+ del y_list
+
+ y = y.reshape(x_ori_shape)
+ return y
+
+
+@triton.autotune(
+ configs=[
+ # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 1024,
+ },
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 2048,
+ },
+ num_warps=4,
+ ),
+ ],
+ key=["n_elements"],
+)
+@triton.jit
+def _floatExMy_quantize_kernel(
+ x_ptr,
+ output_ptr,
+ n_elements,
+ e_bit,
+ m_bit,
+ BLOCK_SIZE: tl.constexpr,
+):
+ if isinstance(e_bit, tl.constexpr):
+ ebit = e_bit.value
+ else:
+ ebit = e_bit
+
+ if isinstance(m_bit, tl.constexpr):
+ mbit = m_bit.value
+ else:
+ mbit = m_bit
+
+ pid = tl.program_id(axis=0)
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+ x = tl.load(x_ptr + offsets, mask=mask)
+
+ x = x.to(tl.float32)
+ sign = 1 - 2 * libdevice.signbit(x)
+ x_abs = tl.abs(x)
+ Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2
+ Ehigh = tl.exp2((ebit - 1).to(tl.float32))
+ Mhigh = tl.exp2(mbit.to(tl.float32))
+ expo = tl.floor(tl.log2(x_abs))
+ expo = tl.clamp(expo, min=Elow, max=Ehigh)
+ mant = x_abs / tl.exp2(expo)
+
+ mant_int = tl.floor(mant)
+ mant_frac = mant - mant_int
+ mant_frac = mant_frac * Mhigh
+ # mant_frac = mant_frac + noise
+ mant_frac = libdevice.round(mant_frac)
+
+ mant_q = mant_int + mant_frac / Mhigh
+ y = sign * tl.exp2(expo) * mant_q
+ y = y.to(x_ptr.dtype.element_ty)
+
+ tl.store(output_ptr + offsets, y, mask=mask)
+
+
+@triton.autotune(
+ configs=[
+ # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 1024,
+ },
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 2048,
+ },
+ num_warps=4,
+ ),
+ ],
+ key=["n_elements"],
+)
+@triton.jit
+def _floatExMy_stochastic_quantize_kernel(
+ x_ptr,
+ noise_ptr,
+ output_ptr,
+ n_elements,
+ e_bit,
+ m_bit,
+ BLOCK_SIZE: tl.constexpr,
+):
+ if isinstance(e_bit, tl.constexpr):
+ ebit = e_bit.value
+ else:
+ ebit = e_bit
+
+ if isinstance(m_bit, tl.constexpr):
+ mbit = m_bit.value
+ else:
+ mbit = m_bit
+
+ pid = tl.program_id(axis=0)
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+ x = tl.load(x_ptr + offsets, mask=mask)
+ noise = tl.load(noise_ptr + offsets, mask=mask)
+
+ x = x.to(tl.float32)
+ sign = 1 - 2 * libdevice.signbit(x)
+ x_abs = tl.abs(x)
+ Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2
+ Ehigh = tl.exp2((ebit - 1).to(tl.float32))
+ Mhigh = tl.exp2(mbit.to(tl.float32))
+ expo = tl.floor(tl.log2(x_abs))
+ expo = tl.clamp(expo, min=Elow, max=Ehigh)
+ mant = x_abs / tl.exp2(expo)
+
+ mant_int = tl.floor(mant)
+ mant_frac = mant - mant_int
+ mant_frac = mant_frac * Mhigh
+ mant_frac = mant_frac + noise
+ mant_frac = libdevice.round(mant_frac)
+
+ mant_q = mant_int + mant_frac / Mhigh
+ y = sign * tl.exp2(expo) * mant_q
+ y = y.to(x_ptr.dtype.element_ty)
+
+ tl.store(output_ptr + offsets, y, mask=mask)
diff --git a/llava/model/__init__.py b/llava/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e5494e136efbc55e4813a40184308c26fe5949b
--- /dev/null
+++ b/llava/model/__init__.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .language_model.llava_llama import LlavaLlamaConfig, LlavaLlamaModel
+
+# FP8 related comments, development in progress (PI: ligeng zhu, haochen xi)
+# NOTE: VLM + LLM
+# from .language_model.qllava_qllama import QLlavaLlamaConfig, QLlavaLlamaModel
+# NOTE: Linear -> fp8, similar to transformer engine
+# from .language_model.qllama import QLlamaConfig, QLlamaForCausalLM, QLlamaModel
+# NOTE: Linear + Activation -> fp8, haochen's iclr version
+# from .language_model.qmemllama import QMemLlamaConfig, QMemLlamaForCausalLM, QMemLlamaModel
+"""
+TODO:
+ linear(weights):
+ simulated fp8: done
+ real fp8: in-progress (code already implemented)
+ activation:
+ simulated fp8: done
+ real fp8: in-progress (still coding)
+ optimizers:
+ current VILA: bf16
+ simulated fp8: done
+ real fp8 + fsdp (single node): done
+ real fp8 + fsdp (multiple node): in-progress
+1. linear fp8
+2. activation fp8
+3. fp8 inference example (load directly from an fp8 checkpoint and run forward)
+4. bind fp8 related configs to QLlamaConfig {"coat_fp8_args": {}}
+"""
+from .language_model.fp8linearqwen2 import FP8LinearQwen2Config, FP8LinearQwen2Model
+from .language_model.qllava_qllama import QLlavaLlamaConfig, QLlavaLlamaModel
diff --git a/llava/model/apply_delta.py b/llava/model/apply_delta.py
new file mode 100644
index 0000000000000000000000000000000000000000..767d0612da9a1fd77058e3d004d16e45d572e3ca
--- /dev/null
+++ b/llava/model/apply_delta.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+"""
+Usage:
+python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
+"""
+import argparse
+
+import torch
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llava import LlavaLlamaForCausalLM
+
+
+def apply_delta(base_model_path, target_model_path, delta_path):
+ print("Loading base model")
+ base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+ print("Loading delta")
+ delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
+
+ print("Applying delta")
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
+ if name not in base.state_dict():
+ assert name in [
+ "model.mm_projector.weight",
+ "model.mm_projector.bias",
+ ], f"{name} not in base model"
+ continue
+ if param.data.shape == base.state_dict()[name].shape:
+ param.data += base.state_dict()[name]
+ else:
+ assert name in [
+ "model.embed_tokens.weight",
+ "lm_head.weight",
+ ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
+ bparam = base.state_dict()[name]
+ param.data[: bparam.shape[0], : bparam.shape[1]] += bparam
+
+ print("Saving target model")
+ delta.save_pretrained(target_model_path)
+ delta_tokenizer.save_pretrained(target_model_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base-model-path", type=str, required=True)
+ parser.add_argument("--target-model-path", type=str, required=True)
+ parser.add_argument("--delta-path", type=str, required=True)
+
+ args = parser.parse_args()
+
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
diff --git a/llava/model/builder.py b/llava/model/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c715ddf2eaeaa59d9807b6fdc3d31a71005067c
--- /dev/null
+++ b/llava/model/builder.py
@@ -0,0 +1,161 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import warnings
+
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PretrainedConfig
+
+from llava.model import LlavaLlamaModel
+from llava.model.utils import is_mm_model
+
+
+def load_pretrained_model(
+ model_path,
+ model_name,
+ model_base=None,
+ load_8bit=False,
+ load_4bit=False,
+ device_map="auto",
+ device="cuda",
+ **kwargs,
+):
+ kwargs = {"device_map": device_map, **kwargs}
+
+ if device != "cuda":
+ kwargs["device_map"] = {"": device}
+
+ if load_8bit:
+ kwargs["load_in_8bit"] = True
+ elif load_4bit:
+ kwargs["load_in_4bit"] = True
+ kwargs["quantization_config"] = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_compute_dtype=torch.float16,
+ bnb_4bit_use_double_quant=True,
+ bnb_4bit_quant_type="nf4",
+ )
+ else:
+ kwargs["torch_dtype"] = torch.float16
+ # kwargs["torch_dtype"] = torch.bfloat16
+
+ if is_mm_model(model_path):
+ # Load LLaVA model
+ ## TODO @yunhao: mind fixing lora
+ if "lora" in model_name.lower() and model_base is None:
+ warnings.warn(
+ "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
+ )
+ if ("lora" in model_name.lower() or "dora" in model_name.lower()) and model_base is not None:
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
+ print(lora_cfg_pretrained)
+ print("Loading LLaVA from base model...")
+ config = AutoConfig.from_pretrained(model_base)
+ prepare_config_for_eval(config, kwargs)
+ model = LlavaLlamaModel.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
+ tokenizer = model.tokenizer
+ token_num, token_dim = model.llm.lm_head.out_features, model.llm.lm_head.in_features
+ if model.llm.lm_head.weight.shape[0] != token_num:
+ model.llm.lm_head.weight = torch.nn.Parameter(
+ torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype)
+ )
+ model.llm.embed_tokens.weight = torch.nn.Parameter(
+ torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype)
+ )
+
+ print("Loading additional LLaVA weights...")
+ if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
+ non_lora_trainables = torch.load(
+ os.path.join(model_path, "non_lora_trainables.bin"),
+ map_location="cpu",
+ )
+ else:
+ # this is probably from HF Hub
+ from huggingface_hub import hf_hub_download
+
+ def load_from_hf(repo_id, filename, subfolder=None):
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
+ return torch.load(cache_file, map_location="cpu")
+
+ non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin")
+ non_lora_trainables = {
+ (k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()
+ }
+ if any(k.startswith("model.model.") for k in non_lora_trainables):
+ non_lora_trainables = {
+ (k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()
+ }
+ model.load_state_dict(non_lora_trainables, strict=False)
+
+ from peft import PeftModel
+
+ print("Loading LoRA weights...")
+ model = PeftModel.from_pretrained(model, model_path)
+ print("Merging LoRA weights...")
+ model = model.merge_and_unload()
+ print("Model is loaded...")
+ else:
+ config = AutoConfig.from_pretrained(model_path)
+ config.resume_path = model_path
+ prepare_config_for_eval(config, kwargs)
+ model = LlavaLlamaModel(config=config, low_cpu_mem_usage=True, **kwargs)
+ tokenizer = model.tokenizer
+ else:
+ # Load language model
+ if model_base is not None:
+ # PEFT model
+ from peft import PeftModel
+
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
+ print(f"Loading LoRA weights from {model_path}")
+ model = PeftModel.from_pretrained(model, model_path)
+ print(f"Merging weights")
+ model = model.merge_and_unload()
+ print("Convert to FP16...")
+ model.to(torch.float16)
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, legacy=False)
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+ model.eval()
+ image_processor = None
+ if is_mm_model(model_path):
+ model.resize_token_embeddings(len(tokenizer))
+
+ if hasattr(model.llm.config, "max_sequence_length"):
+ context_len = model.config.max_sequence_length
+ else:
+ context_len = 2048
+
+ return tokenizer, model, image_processor, context_len
+
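+# Illustrative usage (editor's sketch; the checkpoint path is hypothetical):
+# model_path = "/path/to/checkpoint/af3-7b"
+# tokenizer, model, image_processor, context_len = load_pretrained_model(
+# model_path, get_model_name_from_path(model_path))
+# where get_model_name_from_path is defined in llava/mm_utils.py.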
+
+def prepare_config_for_eval(config: PretrainedConfig, kwargs: dict):
+ try:
+ # compatible with deprecated config convention
+ if getattr(config, "vision_tower_cfg", None) is None:
+ config.vision_tower_cfg = config.mm_vision_tower
+ except AttributeError:
+ raise ValueError(f"Invalid configuration! Cannot find vision_tower in config:\n{config}")
+
+ config.model_dtype = kwargs.pop("torch_dtype").__str__()
diff --git a/llava/model/coat/activation/__init__.py b/llava/model/coat/activation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..69923298db60fa8275bf22895fd5c54cf102e9b9
--- /dev/null
+++ b/llava/model/coat/activation/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
diff --git a/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTorch.py b/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..04acd2fb73f707152839e1a97acb67e6f8df873b
--- /dev/null
+++ b/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTorch.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import math
+
+import torch
+
+
+def floatExMy_quantize_torch(x, e_bit, m_bit, stochastic):
+ sign, x_abs = x.sign(), x.abs()
+ Elow, Ehigh, Mhigh = -(2 ** (e_bit - 1)) + 2, 2 ** (e_bit - 1), 2**m_bit
+ expo = torch.floor(torch.log2(x_abs))
+ expo = torch.clamp(expo, min=Elow, max=Ehigh)
+ mant = x_abs / torch.exp2(expo)
+
+ mant_int = torch.floor(mant)
+ mant_frac = mant - mant_int
+ mant_frac = mant_frac * Mhigh
+ if stochastic:
+ noise = mant_frac.new(mant_frac.shape).uniform_(-0.5, 0.5)
+ mant_frac.add_(noise)
+ mant_frac = torch.round(mant_frac)
+
+ mant_q = mant_int + mant_frac / Mhigh
+ y = sign * (2**expo) * mant_q
+ y = y.to(x)
+
+ return y
+
+
+def floatExM0_quantize_torch(x, e_bit, stochastic):
+ sign, x_abs = x.sign(), x.abs()
+ Elow, Ehigh = -(2 ** (e_bit - 1)) + 1, 2 ** (e_bit - 1)
+ expo = torch.log2(x_abs)
+ if stochastic:
+ noise = expo.new(expo.shape).uniform_(-0.5, 0.5)
+ expo.add_(noise) # in-place add; .add() would discard the result
+ log_bias = math.log2(4 / 3) - 1 / 2
+ expo.add_(torch.ones_like(expo) * log_bias) # in-place add; .add() would discard the result
+ expo = torch.clamp(expo, min=Elow - 1, max=Ehigh)
+ expo = torch.round(expo)
+
+ y = sign * (2**expo) * (expo > Elow) # When underflow, set the value to 0
+ y = y.to(x)
+
+ return y
+
+
+def Dynamic_quantize_torch(x, bit, stochastic):
+ if stochastic:
+ raise NotImplementedError("Dynamic Tree quantization does not support stochastic")
+ sign, x_abs = x.sign(), x.abs()
+ expo = torch.ceil(torch.log10(x_abs))
+ expo = torch.clamp(expo, min=2 - bit)
+ mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1
+
+ mant_frac = mant * 2 ** (bit - 2 - expo.abs())
+ mant_frac = torch.round(mant_frac)
+ mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1
+ y = sign * (10**expo) * mant_frac / 10
+
+ zero_mask = y.abs() > 1.01 * 10 ** (1 - bit)
+ y = y * zero_mask
+ y = y.to(x)
+ return y
+
+
+def ZeroDynamic_quantize_torch(x, bit, stochastic):
+ if stochastic:
+ raise NotImplementedError("Dynamic Tree quantization does not support stochastic")
+ sign, x_abs = x.sign(), x.abs()
+ expo = torch.ceil(torch.log10(x_abs))
+ expo = torch.clamp(expo, min=2 - bit)
+ mant = (10 * x_abs / torch.pow(10, expo) - 1) / 9 # Range from 0 - 1
+
+ mant_frac = mant * 2 ** (bit - 2 - expo.abs())
+ mant_frac = torch.round(mant_frac)
+ mant_frac = mant_frac / (2 ** (bit - 2 - expo.abs())) * 9 + 1
+ y = sign * (10**expo) * mant_frac / 10
+
+ y = y.to(x)
+ return y
diff --git a/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTriton.py b/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTriton.py
new file mode 100644
index 0000000000000000000000000000000000000000..db85e7fc42dc54539f54400ccff280402e126f0d
--- /dev/null
+++ b/llava/model/coat/activation/fake_quantization/FloatPointQuantizeTriton.py
@@ -0,0 +1,181 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import math
+import struct
+
+import numpy as np
+import torch
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+
+def floatExMy_quantize_triton(x, e_bit, m_bit, stochastic):
+ n_elements = x.numel()
+ grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+ y = torch.zeros_like(x)
+
+ if x.dtype in [torch.bfloat16, torch.float32]:
+ if stochastic:
+ noise = x.new(x.shape).uniform_(-0.5, 0.5)
+ _floatExMy_stochastic_quantize_kernel[grid](x, noise, y, n_elements, e_bit, m_bit)
+ else:
+ _floatExMy_quantize_kernel[grid](x, y, n_elements, e_bit, m_bit)
+ else:
+ raise NotImplementedError(f"Other data format {x.dtype} for float quantization triton")
+
+ return y
+
+
+@triton.autotune(
+ configs=[
+ # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 1024,
+ },
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 2048,
+ },
+ num_stages=1,
+ ),
+ ],
+ key=["n_elements"],
+)
+@triton.jit
+def _floatExMy_quantize_kernel(
+ x_ptr,
+ output_ptr,
+ n_elements,
+ e_bit,
+ m_bit,
+ BLOCK_SIZE: tl.constexpr,
+):
+ if isinstance(e_bit, tl.constexpr):
+ ebit = e_bit.value
+ else:
+ ebit = e_bit
+
+ if isinstance(m_bit, tl.constexpr):
+ mbit = m_bit.value
+ else:
+ mbit = m_bit
+
+ pid = tl.program_id(axis=0)
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+ x = tl.load(x_ptr + offsets, mask=mask)
+
+ x = x.to(tl.float32)
+ sign = 1 - 2 * libdevice.signbit(x)
+ x_abs = tl.abs(x)
+ Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2
+ Ehigh = tl.exp2((ebit - 1).to(tl.float32))
+ Mhigh = tl.exp2(mbit.to(tl.float32))
+ expo = tl.floor(tl.log2(x_abs))
+ expo = tl.clamp(expo, min=Elow, max=Ehigh)
+ mant = x_abs / tl.exp2(expo)
+
+ mant_int = tl.floor(mant)
+ mant_frac = mant - mant_int
+ mant_frac = mant_frac * Mhigh
+ # mant_frac = mant_frac + noise
+ mant_frac = libdevice.round(mant_frac)
+
+ mant_q = mant_int + mant_frac / Mhigh
+ y = sign * tl.exp2(expo) * mant_q
+ y = y.to(x_ptr.dtype.element_ty)
+
+ tl.store(output_ptr + offsets, y, mask=mask)
+
+
+@triton.autotune(
+ configs=[
+ # triton.Config({'BLOCK_SIZE': 4,}, num_warps=4),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 1024,
+ },
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE": 2048,
+ },
+ num_stages=1,
+ ),
+ ],
+ key=["n_elements"],
+)
+@triton.jit
+def _floatExMy_stochastic_quantize_kernel(
+ x_ptr,
+ noise_ptr,
+ output_ptr,
+ n_elements,
+ e_bit,
+ m_bit,
+ BLOCK_SIZE: tl.constexpr,
+):
+ if isinstance(e_bit, tl.constexpr):
+ ebit = e_bit.value
+ else:
+ ebit = e_bit
+
+ if isinstance(m_bit, tl.constexpr):
+ mbit = m_bit.value
+ else:
+ mbit = m_bit
+
+ pid = tl.program_id(axis=0)
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+ x = tl.load(x_ptr + offsets, mask=mask)
+ noise = tl.load(noise_ptr + offsets, mask=mask)
+
+ x = x.to(tl.float32)
+ sign = 1 - 2 * libdevice.signbit(x)
+ x_abs = tl.abs(x)
+ Elow = -tl.exp2((ebit - 1).to(tl.float32)) + 2
+ Ehigh = tl.exp2((ebit - 1).to(tl.float32))
+ Mhigh = tl.exp2(mbit.to(tl.float32))
+ expo = tl.floor(tl.log2(x_abs))
+ expo = tl.clamp(expo, min=Elow, max=Ehigh)
+ mant = x_abs / tl.exp2(expo)
+
+ mant_int = tl.floor(mant)
+ mant_frac = mant - mant_int
+ mant_frac = mant_frac * Mhigh
+ mant_frac = mant_frac + noise
+ mant_frac = libdevice.round(mant_frac)
+
+ mant_q = mant_int + mant_frac / Mhigh
+ y = sign * tl.exp2(expo) * mant_q
+ y = y.to(x_ptr.dtype.element_ty)
+
+ tl.store(output_ptr + offsets, y, mask=mask)
diff --git a/llava/model/coat/activation/fake_quantization/quantize_function.py b/llava/model/coat/activation/fake_quantization/quantize_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b18d45603ecd3f6564a8dbac3810d8ce701fbb6
--- /dev/null
+++ b/llava/model/coat/activation/fake_quantization/quantize_function.py
@@ -0,0 +1,239 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+
+import torch
+
+from .FloatPointQuantizeTorch import *
+from .FloatPointQuantizeTriton import *
+
+
+def block_cut(input, row_block, column_block, pad_block=False):
+ # print(input.shape)
+ original_shape = input.shape
+ # input tensor shape is M * N
+ if len(input.shape) > 2:
+ input = input.reshape(-1, input.shape[2])
+ elif len(input.shape) == 2:
+ pass
+ else:
+ raise ValueError(f"input shape {input.shape} does not match for block cut, {input}")
+ M, N = input.shape[0], input.shape[1]
+
+ if row_block == -1:
+ row_block = M
+ if column_block == -1:
+ column_block = N
+
+ if pad_block:
+ row_remainder, col_remainder = M % row_block, N % column_block
+ if row_remainder:
+ row_pad = row_block - row_remainder
+ else:
+ row_pad = 0
+ if col_remainder:
+ col_pad = column_block - col_remainder
+ else:
+ col_pad = 0
+
+ input = torch.nn.functional.pad(
+ input, (0, col_pad, 0, row_pad), "constant", 0
+ ) # F.pad pads the last dim first, so this adds col_pad columns on the right and row_pad rows at the bottom
+ M, N = input.shape[0], input.shape[1]
+ row_num, column_num = M // row_block, N // column_block
+ else:
+ row_num, column_num = M // row_block, N // column_block
+
+ assert row_num * row_block == M, f"{row_num}, {row_block}, {M}, {original_shape}"
+ assert column_num * column_block == N, f"{column_num}, {column_block}, {N}, {original_shape}"
+ # print(input.shape)
+ input = (
+ input.reshape(row_num, row_block, column_num, column_block)
+ .permute(0, 2, 1, 3)
+ .reshape(row_num * column_num, row_block, column_block)
+ )
+ # print(input.shape)
+ return input
+
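+# Worked example for block_cut (editor's note): a 4x6 tensor with row_block=2 and
+# column_block=3 is split into row_num=2 x column_num=2 blocks and returned with
+# shape (4, 2, 3), i.e. four 2x3 blocks.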
+
+def block_reshape(input, origin_input, row_block, column_block, pad_block=False):
+ if len(origin_input.shape) > 2:
+ flatten_input = origin_input.reshape(-1, origin_input.shape[2])
+ elif len(origin_input.shape) == 2:
+ flatten_input = origin_input
+ else:
+ raise ValueError(f"input shape {input.shape} does not match for block cut")
+
+ M, N = flatten_input.shape[0], flatten_input.shape[1]
+
+ if row_block == -1:
+ row_block = M
+ if column_block == -1:
+ column_block = N
+
+ if pad_block:
+ row_remainder, col_remainder = M % row_block, N % column_block
+ if row_remainder:
+ row_pad = row_block - row_remainder
+ else:
+ row_pad = 0
+ if col_remainder:
+ col_pad = column_block - col_remainder
+ else:
+ col_pad = 0
+
+ pad_origin_input = torch.nn.functional.pad(origin_input, (0, col_pad, 0, row_pad), "constant", 0)
+ M, N = pad_origin_input.shape[0], pad_origin_input.shape[1]
+ row_num, column_num = M // row_block, N // column_block
+ else:
+ row_num, column_num = M // row_block, N // column_block
+
+ input = (
+ input.reshape(row_num, column_num, row_block, column_block)
+ .permute(0, 2, 1, 3)
+ .reshape(row_num * row_block, column_num * column_block)
+ )
+
+ M, N = flatten_input.shape[0], flatten_input.shape[1]
+ input = input[:M, :N]
+
+ if len(origin_input.shape) > 2:
+ input = input.reshape(origin_input.shape)
+ elif len(origin_input.shape) == 2:
+ pass
+ else:
+ raise ValueError(f"input shape {input.shape} does not match for block reshape")
+
+ return input
+
+
+def block_verify_int8(input, row_block, column_block, layer_type, necessary=True):
+ Binput = block_cut(input, row_block, column_block)
+ Binput = Binput.to(torch.float32)
+
+ for n in range(Binput.shape[0]):
+ unique_values = len(torch.unique(Binput[n, :, :]))
+ if unique_values > 256:
+ if necessary:
+ raise ValueError(f"{layer_type} contains more than 256 unique values.")
+ else:
+ return False
+ return True
+
+
+def block_quant(input, symm, bits, stochastic, epsilon, apply_quantize, layer_name):
+ Quant_fn = SymmQuantizer
+ return Quant_fn.apply(input, symm, bits, stochastic, epsilon, apply_quantize, layer_name)
+
+
+def extract_bit(string):
+ match = re.match(r"INT(\d+)", string) # INT8
+ if match:
+ return "integer", int(match.group(1)), None
+ match = re.match(r"E(\d+)M(\d+)", string) # E4M3 / E5M2
+ if match:
+ Ebit, Mbit = int(match.group(1)), int(match.group(2))
+ if Ebit == 1:
+ return "integer", Mbit + 1, None
+ if Mbit == 0:
+ return "floatExM0", int(match.group(1)), 0
+ return "floatExMy", int(match.group(1)), int(match.group(2))
+ match = re.match(r"DE(\d+)", string)
+ if match:
+ return "Dynamic", int(match.group(1)), None
+ match = re.match(r"ZeroD(\d+)", string)
+ if match:
+ return "ZeroDynamic", int(match.group(1)), None
+ raise ValueError(f"{string} data format is not supported")
+
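+# Examples (editor's note): extract_bit("E4M3") returns ("floatExMy", 4, 3),
+# extract_bit("INT8") returns ("integer", 8, None), and extract_bit("E5M2")
+# returns ("floatExMy", 5, 2).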
+
+class SymmQuantizer(torch.autograd.function.InplaceFunction):
+ @staticmethod
+ def forward(ctx, input, symm, bits, stochastic, epsilon, apply_quantize=True, layer_name=None):
+ with torch.no_grad():
+ absmax_per_block = input.abs().amax(dim=(1, 2)).unsqueeze(1).unsqueeze(2) + epsilon
+
+ if bits == "100" or not apply_quantize:
+ return input, input, torch.ones_like(absmax_per_block)
+ elif bits == "FP32":
+ return input.to(torch.float32), input.to(torch.float32), torch.ones_like(absmax_per_block)
+ elif bits == "FP16":
+ return input.to(torch.float16), input.to(torch.float16), torch.ones_like(absmax_per_block)
+ elif bits == "BF16":
+ return input.to(torch.bfloat16), input.to(torch.bfloat16), torch.ones_like(absmax_per_block)
+ else:
+ QuantType, bit1, bit2 = extract_bit(bits)
+ if not symm:
+ bit1 = bit1 + 1 # pretend to be asymmetric
+
+ if QuantType == "integer":
+ Qn, Qp = -(2 ** (bit1 - 1) - 1), 2 ** (bit1 - 1) - 1
+ elif QuantType == "floatExMy":
+ Qn, Qp = -(2 - 2 ** (-bit2)) * (2 ** (2 ** (bit1 - 1))), (2 - 2 ** (-bit2)) * (
+ 2 ** (2 ** (bit1 - 1))
+ )
+ if bit1 == 4 and bit2 == 3: # E4M3
+ Qn, Qp = -448, 448
+ if bit1 == 5 and bit2 == 2: # E5M2
+ Qn, Qp = -57344, 57344
+ elif QuantType == "floatExM0":
+ Qn, Qp = -(2 ** (2 ** (bit1 - 1))) + 1, 2 ** (2 ** (bit1 - 1))
+ elif QuantType == "Dynamic":
+ Qn, Qp = -1, 1
+ elif QuantType == "ZeroDynamic":
+ Qn, Qp = -1, 1
+ else:
+ raise NotImplementedError(f"{bits} is not supported by quantization")
+ scale_per_block = (2 * absmax_per_block) / (Qp - Qn)
+ scale_per_block = scale_per_block.to(input)
+
+ Qinput = input / scale_per_block
+
+ if QuantType == "integer":
+ if stochastic:
+ noise = Qinput.new(Qinput.shape).uniform_(-0.5, 0.5)
+ Qinput.add_(noise)
+ Qinput.clamp_(Qn, Qp).round_()
+ elif QuantType == "floatExMy":
+ # Qinput = floatExMy_quantize_torch(Qinput, bit1, bit2, stochastic)
+ Qinput = floatExMy_quantize_triton(Qinput, bit1, bit2, stochastic)
+ elif QuantType == "floatExM0":
+ Qinput = floatExM0_quantize_torch(Qinput, bit1, stochastic)
+ else:
+ raise NotImplementedError(f"{bits} is not supported by quantization")
+
+ RQinput = Qinput * scale_per_block
+
+ if input.dtype != Qinput.dtype:
+ print(
+ f"Input type is {input.dtype}, Qinput type is {Qinput.dtype}, scale_per_block type is {scale_per_block.dtype}",
+ file=open("debug.txt", "a"),
+ )
+ import IPython
+
+ IPython.embed()
+ return RQinput, Qinput, scale_per_block
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output, None, None, None, None, None
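+# Illustrative behaviour of SymmQuantizer (editor's sketch, not part of the original code):
+# with bits="E4M3" and symm=True, Qp = 448 and Qn = -448, so a block whose absolute
+# maximum is 3.5 gets scale_per_block ~= 2 * 3.5 / 896 = 0.0078125, and forward() returns
+# (dequantized tensor, E4M3-quantized tensor, per-block scales).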
diff --git a/llava/model/coat/activation/fake_quantization/utils.py b/llava/model/coat/activation/fake_quantization/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc4ec2dae1ef7be1dc7c82c79439b7091d49feb6
--- /dev/null
+++ b/llava/model/coat/activation/fake_quantization/utils.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+
+def list_has_common_element(list1, list2):
+ set1 = set(list1)
+ set2 = set(list2)
+ return len(set1.intersection(set2)) > 0
+
+
+def calculate_scale_num(input, row_block, col_block):
+ if len(input.shape) > 2:
+ input = input.reshape(-1, input.shape[2])
+ elif len(input.shape) == 2:
+ pass
+ else:
+ raise ValueError(f"input shape {input.shape} does not match for block cut, {input}")
+ M, N = input.shape[0], input.shape[1]
+
+ if row_block == -1:
+ row_block = M
+ if col_block == -1:
+ col_block = N
+
+ return input.numel() / (row_block * col_block)
+
+
+def quant_get_local_rank() -> int:
+ return int(os.environ.get("LOCAL_RANK") or 0)
+
+
+def format_string_with_condition(
+ input_string,
+ condition_config,
+ symm,
+ bits,
+ blocksize_config,
+ input_pad=20,
+):
+ padded_string = input_string.ljust(input_pad)
+ output_string = padded_string
+
+ for k, v in condition_config.items():
+ if v:
+ output_string = output_string + k.ljust(10) + "True".ljust(6) + "".ljust(6)
+ else:
+ output_string = output_string + k.ljust(10) + "".ljust(6) + "False".ljust(6)
+
+ output_string = output_string + f"Symm {symm}".ljust(10)
+
+ for k, v in bits.items():
+ output_string = output_string + f"{k} bit".ljust(10) + v.ljust(10)
+ for k, v in blocksize_config.items():
+ output_string += f"{k}: {v}".ljust(15)
+
+ return output_string
+
+
+def print_warning(sentence):
+ print("*" * (len(sentence) + 4))
+ print(f"* {sentence} *")
+ print("*" * (len(sentence) + 4))
+
+
+def check_nan_inf(tensor, check_nan, check_inf):
+ if check_nan:
+ contain_nan = torch.isnan(tensor).any()
+ else:
+ contain_nan = False
+ if check_inf:
+ contain_inf = torch.isinf(tensor).any()
+ else:
+ contain_inf = False
+ return contain_nan, contain_inf
+
+
+def move_torch_to_numpy(tensor):
+ if tensor is None:
+ return None
+
+ if tensor.is_cuda:
+ tensor = tensor.cpu()
+ return tensor.detach().float().numpy()
+
+
+def flatten_to_1d(tensor):
+ if tensor is None:
+ return None
+
+ return tensor.reshape(-1)
diff --git a/llava/model/coat/activation/models/_fp8_quantization_config.py b/llava/model/coat/activation/models/_fp8_quantization_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..579df908203575abf8da9565995640c4ae459e3b
--- /dev/null
+++ b/llava/model/coat/activation/models/_fp8_quantization_config.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from dataclasses import dataclass
+
+from transformers import PretrainedConfig
+
+
+@dataclass
+class QuantizationConfig:
+ quantize_model: str = "false"
+ symm: bool = True
+ epsilon: float = 1e-10
+ fabit: str = "E4M3"
+ fwbit: str = "E4M3"
+ fobit: str = "E4M3"
+ babit: str = "E5M2"
+ bwbit: str = "E5M2"
+ bobit: str = "E5M2"
+ qchoice: str = "none"
+ group_size: int = -1
+ pad_to_multiple_of: int = 0
+ weight_memory_efficient: bool = True
+
+ # Legacy
+ row_blocksize: int = -1
+ col_blocksize: int = -1
+
+ def __init__(
+ self,
+ quantize_model: str = "false",
+ symm: bool = True,
+ epsilon: float = 1e-10,
+ fabit: str = "E4M3",
+ fwbit: str = "E4M3",
+ fobit: str = "E4M3",
+ babit: str = "E5M2",
+ bwbit: str = "E5M2",
+ bobit: str = "E5M2",
+ qchoice: str = "none",
+ group_size: int = -1,
+ pad_to_multiple_of: int = 0,
+ weight_memory_efficient: bool = True,
+ row_blocksize: int = -1,
+ col_blocksize: int = -1,
+ **kwargs,
+ ):
+ super().__init__()
+ self.quantize_model = quantize_model
+ self.symm = symm
+ self.epsilon = epsilon
+ self.fabit = fabit
+ self.fwbit = fwbit
+ self.fobit = fobit
+ self.babit = babit
+ self.bwbit = bwbit
+ self.bobit = bobit
+ self.qchoice = qchoice
+ self.group_size = group_size
+ self.pad_to_multiple_of = pad_to_multiple_of
+ self.weight_memory_efficient = weight_memory_efficient
+
+ self.row_blocksize = row_blocksize
+ self.col_blocksize = col_blocksize
diff --git a/llava/model/coat/activation/models/_fp8_weightcache.py b/llava/model/coat/activation/models/_fp8_weightcache.py
new file mode 100644
index 0000000000000000000000000000000000000000..85b1c504c74b83299181020e1275133428f93e1d
--- /dev/null
+++ b/llava/model/coat/activation/models/_fp8_weightcache.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import torch.nn as nn
+
+from ..real_quantization import fp8_division_transpose
+
+
+class FP8CacheWeightModule(nn.Module):
+ def __init__(self, config, qargs, layer_id):
+ super().__init__()
+ self.config = config
+ self.qargs = qargs
+ self.layer_id = layer_id
+
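+ # On the first microbatch the weight is quantized to FP8 and cached on the module (in
+ # weight_memory_efficient mode only the scale is cached); later microbatches reuse the cached values.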
+ def prepare_weight(self, weight, weight_name, is_first_microbatch):
+ if is_first_microbatch:
+ if self.qargs.weight_memory_efficient:
+ # print(f"{weight_name} uses first microbatch")
+ weight_fp8, weight_s, weight_fp8_t = fp8_division_transpose(
+ weight, self.qargs.group_size, self.fwobits["fwbit"]
+ )
+ setattr(self, f"{weight_name}_fp8_scale", weight_s)
+ return weight_fp8, weight_fp8_t, weight_s
+ else:
+ # print(f"{weight_name} uses first microbatch")
+ weight_fp8, weight_s, weight_fp8_t = fp8_division_transpose(
+ weight, self.qargs.group_size, self.fwobits["fwbit"]
+ )
+ setattr(self, f"{weight_name}_fp8", weight_fp8)
+ setattr(self, f"{weight_name}_fp8_t", weight_fp8_t)
+ setattr(self, f"{weight_name}_fp8_scale", weight_s)
+ return weight_fp8, weight_fp8_t, weight_s
+ else:
+ if self.qargs.weight_memory_efficient:
+ return getattr(self, f"{weight_name}_fp8_scale")
+ else:
+ return (
+ getattr(self, f"{weight_name}_fp8"),
+ getattr(self, f"{weight_name}_fp8_t"),
+ getattr(self, f"{weight_name}_fp8_scale"),
+ )
+
+ def forward(self, x):
+ pass
diff --git a/llava/model/coat/activation/models/_fp8manager.py b/llava/model/coat/activation/models/_fp8manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..5270acf79adc7d17838959b0c679cba3b43c5bee
--- /dev/null
+++ b/llava/model/coat/activation/models/_fp8manager.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+class FP8Manager:
+ """Class to keep track of and manipulate the global
+ FP8 state at different stages of execution.
+ """
+
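+ # Expected to be toggled by the outer training loop (inferred from usage): set True at the start of a
+ # global step so the Coat modules refresh their cached FP8 weights, then False for later microbatches.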
+ is_first_microbatch = False
diff --git a/llava/model/coat/activation/models/coat_llama.py b/llava/model/coat/activation/models/coat_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b6763c027ac05d04c1b66b483cf886c0e1f7fab
--- /dev/null
+++ b/llava/model/coat/activation/models/coat_llama.py
@@ -0,0 +1,1479 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import os
+from fnmatch import fnmatch
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.generation import GenerationMixin
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_flash_attention_utils import _flash_attention_forward
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDynamicNTKScalingRotaryEmbedding,
+ LlamaForCausalLM,
+ LlamaLinearScalingRotaryEmbedding,
+ LlamaModel,
+ LlamaPreTrainedModel,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+ _prepare_4d_causal_attention_mask_with_cache_position,
+ apply_rotary_pos_emb,
+ repeat_kv,
+ rotate_half,
+)
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10,
+ is_torchdynamo_compiling,
+ logging,
+ replace_return_docstrings,
+)
+
+from ..real_quantization import (
+ Coat_quantize_bgn,
+ Coat_quantize_end,
+ fp8_add_Ifp_Ifp_Ofp_Og16,
+ fp8_add_Ifp_Ifp_Ofp_Opt,
+ fp8_division,
+ fp8_division_transpose,
+ fp8_gelu_backward,
+ fp8_gelu_forward,
+ fp8_layernorm_noparam_backward,
+ fp8_layernorm_noparam_forward,
+ fp8_linear_backward,
+ fp8_linear_forward,
+ fp8_mul_backward,
+ fp8_mul_forward,
+ fp8_quantize,
+ fp8_quantize_pertensor,
+ fp8_quantize_pertensor_transpose,
+ fp8_rmsnorm_backward,
+ fp8_rmsnorm_forward,
+ fp8_silu_backward,
+ fp8_silu_forward,
+ fp8_transpose,
+)
+
+# FP8 related
+from ._fp8_quantization_config import QuantizationConfig
+from ._fp8_weightcache import FP8CacheWeightModule
+from ._fp8manager import FP8Manager
+
+logger = logging.get_logger(__name__)
+
+
+class CoatLlamaConfig(LlamaConfig):
+ model_type = "fp8_llama"
+
+
+class CoatLlamaBeforeAttentionResidual(FP8CacheWeightModule):
+ """
+ This is the pre-attention part of a transformer block: it fuses (1) the residual stream, (2) the LayerNorm / RMSNorm, and (3) the Q/K/V projection linear layers.
+ """
+
+ def __init__(self, config: CoatLlamaConfig, qargs: QuantizationConfig, layer_idx: Optional[int] = None):
+ super().__init__(config, qargs, layer_idx)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+
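+ # forward() takes the residual-stream triplet (re_x in BF16, x as FP8 values, s as per-group scales)
+ # together with the RMSNorm weight and, in training, returns (residual, query, key, value states)
+ # from the fused autograd op.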
+ def forward(self, re_x, x, s, rmsnorm_weight):
+ if self.training:
+ if self.qargs.weight_memory_efficient:
+ # Prepare
+ with torch.no_grad():
+ weight1_s = self.prepare_weight(self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch)
+ weight2_s = self.prepare_weight(self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch)
+ weight3_s = self.prepare_weight(self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch)
+ return _CoatLlamaBeforeAttentionResidual.apply(
+ re_x,
+ x,
+ s,
+ self.q_proj.weight,
+ None,
+ None,
+ weight1_s,
+ self.k_proj.weight,
+ None,
+ None,
+ weight2_s,
+ self.v_proj.weight,
+ None,
+ None,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # Prepare
+ with torch.no_grad():
+ weight1, weight1_t, weight1_s = self.prepare_weight(
+ self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch
+ )
+ weight2, weight2_t, weight2_s = self.prepare_weight(
+ self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch
+ )
+ weight3, weight3_t, weight3_s = self.prepare_weight(
+ self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch
+ )
+ return _CoatLlamaBeforeAttentionResidual.apply(
+ re_x,
+ x,
+ s,
+ self.q_proj.weight,
+ weight1,
+ weight1_t,
+ weight1_s,
+ self.k_proj.weight,
+ weight2,
+ weight2_t,
+ weight2_s,
+ self.v_proj.weight,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
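+ # Note: this non-training branch references att_proj / attn_norm, which are not defined on this module;
+ # it appears to be unused legacy code, since inference is expected to go through a separate path.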
+ return re_x, self.att_proj(self.attn_norm(re_x))
+
+
+class _CoatLlamaBeforeAttentionResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ re_x,
+ in_x,
+ in_s,
+ weight1_origin,
+ weight1,
+ weight1_t,
+ weight1_s,
+ weight2_origin,
+ weight2,
+ weight2_t,
+ weight2_s,
+ weight3_origin,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ group_size,
+ fwobits,
+ layer_id,
+ config,
+ qargs,
+ eps=1e-5,
+ ):
+ # for autograd
+ if fwobits["fabit"] == "E4M3":
+ # in_x = in_x.to(torch.float8_e4m3fn)
+ in_x = in_x.view(torch.float8_e4m3fn)
+ else:
+ raise ValueError("fabit should be E4M3")
+
+ # LayerNorm
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward(
+ in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True
+ )
+
+ # Linear Layer QKV Projection
+ if qargs.weight_memory_efficient:
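+ # In memory-efficient mode the FP8 weight copies are not cached on the module, so they are
+ # re-materialized here from the master weights using the cached scales.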
+ assert weight1 is None # memory efficient
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
+
+ fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size) # query states
+ fc2_x = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, False, group_size) # key states
+ fc3_x = fp8_linear_forward(ln_x, ln_s, weight3, weight3_s, False, group_size) # value states
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s)
+ if qargs.weight_memory_efficient:
+ assert weight1_t is None and weight2_t is None and weight3_t is None
+ ctx.weight = weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s
+ else:
+ ctx.weight = weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s
+
+ ctx.group_size = group_size
+ ctx.ln_utils = ln_utils
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ return re_x, fc1_x, fc2_x, fc3_x
+
+ @staticmethod
+ def backward(ctx, fp_grad, query_g, key_g, value_g):
+ in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors
+ weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s = ctx.weight
+
+ group_size = ctx.group_size
+ rms_weight, rstd, num_warps = ctx.ln_utils
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # ==================== Begin backward ====================
+ # Quantize the RoPE and FlashAttention Output. grad_input and grad_weight requires different data layout.
+ query_g, query_gs, query_g_t = fp8_quantize_pertensor_transpose(
+ query_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+ key_g, key_gs, key_g_t = fp8_quantize_pertensor_transpose(
+ key_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+ value_g, value_gs, value_g_t = fp8_quantize_pertensor_transpose(
+ value_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+
+ # Linear Layer QKV Projection
+ if qargs.weight_memory_efficient:
+ weight1_t, weight1_s = fp8_division_transpose(
+ weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True
+ )
+ weight2_t, weight2_s = fp8_division_transpose(
+ weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True
+ )
+ weight3_t, weight3_s = fp8_division_transpose(
+ weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True
+ )
+
+ fc1_g1, att_q_wg = fp8_linear_backward(
+ ln_x_t, ln_s, query_g, query_gs, query_g_t, weight1_t, weight1_s, group_size
+ )
+ fc1_g2, att_k_wg = fp8_linear_backward(ln_x_t, ln_s, key_g, key_gs, key_g_t, weight2_t, weight2_s, group_size)
+ fc1_g3, att_v_wg = fp8_linear_backward(
+ ln_x_t, ln_s, value_g, value_gs, value_g_t, weight3_t, weight3_s, group_size
+ )
+
+ fc1_g = fc1_g1 + fc1_g2 + fc1_g3
+
+ # LayerNorm
+ in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc1_g, rms_weight, rstd, group_size, num_warps)
+
+ # Add the gradient together, and prepare the input of the next layer.
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
+ fp_grad, in_g, group_size, fwobits["babit"], stochastic=False
+ )
+
+ # For autograd the gradient must be viewed with the same dtype as the corresponding forward tensor; the view does not change the underlying bits.
+ in_g = in_g.view(torch.float8_e4m3fn)
+
+ # Although the next operator is a linear layer in the MLPResidual module, we return in_sg_g16 so the gradient sizes match the forward signature; otherwise autograd would reject it.
+ return (
+ re_g,
+ in_g,
+ in_sg_g16,
+ att_q_wg,
+ None,
+ None,
+ None,
+ att_k_wg,
+ None,
+ None,
+ None,
+ att_v_wg,
+ None,
+ None,
+ None,
+ rms_weight_grad,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class CoatLlamaAfterAttentionResidual(FP8CacheWeightModule):
+ """
+ This is the post-attention part of a transformer block: it fuses (1) the residual addition and (2) the output (o_proj) linear layer.
+ """
+
+ def __init__(self, config: CoatLlamaConfig, qargs: QuantizationConfig, layer_id):
+ super().__init__(config, qargs, layer_id)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+
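+ # forward() applies the attention output projection and the residual addition in one fused autograd op,
+ # returning the updated (BF16 residual, FP8 activation, scale) triplet for the next block.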
+ def forward(self, re_x, in_x):
+ if self.training:
+ if self.qargs.weight_memory_efficient:
+ # prepare for the weight
+ with torch.no_grad():
+ weight4_s = self.prepare_weight(self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch)
+
+ return _CoatLlamaAfterAttentionResidual.apply(
+ re_x,
+ in_x,
+ self.o_proj.weight,
+ None,
+ None,
+ weight4_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # prepare for the weight
+ with torch.no_grad():
+ weight4, weight4_t, weight4_s = self.prepare_weight(
+ self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch
+ )
+
+ return _CoatLlamaAfterAttentionResidual.apply(
+ re_x,
+ in_x,
+ self.o_proj.weight,
+ weight4,
+ weight4_t,
+ weight4_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ return re_x + self.o_proj(in_x), None, None # non-training path: output projection + residual (the original self.attn_out is not defined on this module)
+
+
+class _CoatLlamaAfterAttentionResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx, re_x, flash_x, weight4_origin, weight4, weight4_t, weight4_s, group_size, fwobits, layer_id, config, qargs
+ ):
+ # Quantize the FlashAttention Output
+ flash_qx, flash_s, _ = fp8_quantize_pertensor(
+ flash_x, group_size, fwobits["fabit"]
+ ) # Modified to make it memory efficient
+
+ # Attention output projection linear layer
+ if qargs.weight_memory_efficient:
+ assert weight4 is None # memory efficient
+ weight4, weight4_s = fp8_division(weight4_origin, qargs.group_size, fwobits["fwbit"], weight4_s)
+ fc4_x = fp8_linear_forward(flash_qx, flash_s, weight4, weight4_s, False, group_size)
+
+ # Add the activations together
+ fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc4_x, flash_qx.dtype, group_size)
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(flash_x, flash_s)
+ if qargs.weight_memory_efficient:
+ assert weight4_t is None
+ ctx.weight = weight4_origin, weight4_s
+ else:
+ ctx.weight = weight4_t, weight4_s
+ ctx.group_size = group_size
+ ctx.fwobits = fwobits
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ # For autograd
+ out_x = out_x.view(torch.float8_e4m3fn)
+
+ return fp_x, out_x, out_s
+
+ @staticmethod
+ def backward(ctx, fp_grad, out_g, out_gs):
+ flash_x, flash_s = ctx.saved_tensors
+ weight4_t, weight4_s = ctx.weight
+ group_size = ctx.group_size
+ fwobits = ctx.fwobits
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # for autograd
+ if fwobits["babit"] == "E5M2":
+ # out_g = out_g.to(torch.float8_e5m2)
+ out_g = out_g.view(torch.float8_e5m2)
+ else:
+ raise ValueError("babit should be E5M2")
+ out_gs_max = out_gs.max()
+
+ # ==================== Begin backward ====================
+ # Output Projection
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
+
+ # We do not save an extra transposed copy of flash_x, to reduce memory usage
+ flash_x_t, flash_s = fp8_division_transpose(
+ flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True
+ )
+
+ if qargs.weight_memory_efficient:
+ weight4_t, weight4_s = fp8_division_transpose(
+ weight4_t, qargs.group_size, fwobits["fwbit"], weight4_s, only_transposed=True
+ )
+ fc4_g, attn_out_wg = fp8_linear_backward(
+ flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight4_t, weight4_s, group_size
+ )
+
+ return fp_grad, fc4_g, attn_out_wg, None, None, None, None, None, None, None, None
+
+
+class CoatLlamaMLPResidual(FP8CacheWeightModule):
+ """
+ This is the MLP part of a transformer block: it fuses (1) the residual stream, (2) the LayerNorm / RMSNorm,
+ (3) the gate / up / down projection linear layers, and (4) the SiLU activation.
+ """
+
+ def __init__(self, config: CoatLlamaConfig, qargs: QuantizationConfig, layer_id, hidden_size: int):
+ super().__init__(config, qargs, layer_id)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+ self.training = True
+
+ # below is only used when training = False
+ assert config.hidden_act == "silu", "We only support silu activation currently"
+ self.act_fn = ACT2FN[config.hidden_act]
+
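+ # forward() fuses RMSNorm, the gate/up projections, SiLU, the element-wise product and the down
+ # projection in FP8, and again emits the (BF16 residual, FP8 activation, scale) triplet.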
+ def forward(self, re_x, x, s, rmsnorm_weight):
+ if self.training:
+ if self.qargs.weight_memory_efficient: # prepare for the weight
+ with torch.no_grad():
+ weight1_s = self.prepare_weight(self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch)
+ weight2_s = self.prepare_weight(self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch)
+ weight3_s = self.prepare_weight(self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch)
+
+ return _CoatLlamaMLPResidual.apply(
+ re_x,
+ x,
+ s,
+ self.gate_proj.weight,
+ None,
+ None,
+ weight1_s,
+ self.up_proj.weight,
+ None,
+ None,
+ weight2_s,
+ self.down_proj.weight,
+ None,
+ None,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # prepare for the weight
+ with torch.no_grad():
+ weight1, weight1_t, weight1_s = self.prepare_weight(
+ self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch
+ )
+ weight2, weight2_t, weight2_s = self.prepare_weight(
+ self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch
+ )
+ weight3, weight3_t, weight3_s = self.prepare_weight(
+ self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch
+ )
+
+ return _CoatLlamaMLPResidual.apply(
+ re_x,
+ x,
+ s,
+ self.gate_proj.weight,
+ weight1,
+ weight1_t,
+ weight1_s,
+ self.up_proj.weight,
+ weight2,
+ weight2_t,
+ weight2_s,
+ self.down_proj.weight,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ raise NotImplementedError("Need TODO")
+ og_x = re_x
+ re_x = self.ff_norm(re_x)
+ re_x = self.ff_proj(re_x)
+ re_x = self.act(re_x)
+ re_x = self.ff_out(re_x)
+ re_x = og_x + re_x
+ return re_x, None, None
+
+
+class _CoatLlamaMLPResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ re_x,
+ in_x,
+ in_s,
+ weight1_origin,
+ weight1,
+ weight1_t,
+ weight1_s,
+ weight2_origin,
+ weight2,
+ weight2_t,
+ weight2_s,
+ weight3_origin,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ group_size,
+ fwobits,
+ layer_id,
+ config,
+ qargs,
+ eps=1e-5,
+ ):
+ # For autograd
+ if fwobits["fabit"] == "E4M3":
+ # in_x = in_x.to(torch.float8_e4m3fn)
+ in_x = in_x.view(torch.float8_e4m3fn)
+ else:
+ raise ValueError("fabit should be E4M3")
+
+ # LayerNorm
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward(
+ in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True
+ )
+
+ # Linear layers for the Gate and Up projections (computed as two separate FP8 GEMMs below).
+ if qargs.weight_memory_efficient:
+ assert weight1 is None and weight2 is None and weight3 is None # memory efficient
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
+
+ gate_x, gate_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size) # Gate Proj
+ up_x, up_s = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, True, group_size) # Up Proj
+
+ # silu Activation
+ silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size)
+
+ # Element-wise Multiplication
+ mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True)
+
+ # Output Projection
+ if weight3 is None: # memory efficient
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
+ fc3_x = fp8_linear_forward(mul_x, mul_s, weight3, weight3_s, False, group_size)
+
+ # Add the activation together
+ fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc3_x, mul_x.dtype, group_size)
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s)
+
+ if (
+ qargs.weight_memory_efficient
+ ): # The original weights are not saved twice, which is more memory efficient.
+ assert weight1_t is None and weight2_t is None and weight3_t is None
+ ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s)
+ else: # The transposed weights differ from the original ones, so saving them consumes additional memory.
+ ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s)
+
+ ctx.group_size = group_size
+ ctx.ln_utils = ln_utils
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ out_x = out_x.view(torch.float8_e4m3fn)
+
+ return fp_x, out_x, out_s
+
+ @staticmethod
+ def backward(ctx, fp_grad, out_g, out_gs):
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s = ctx.saved_tensors
+
+ (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) = ctx.weight
+ group_size = ctx.group_size
+ rms_weight, rstd, num_warps = ctx.ln_utils
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # For autograd
+ if fwobits["babit"] == "E5M2":
+ # out_g = out_g.to(torch.float8_e5m2)
+ out_g = out_g.view(torch.float8_e5m2)
+ else:
+ raise ValueError("babit should be E5M2")
+ out_gs_max = out_gs.max()
+
+ # ==================== Begin backward ====================
+ # Output Projection
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
+
+ if qargs.weight_memory_efficient:
+ weight3_t, weight3_s = fp8_division_transpose(
+ weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True
+ )
+ fc3_g, weight3_grad = fp8_linear_backward(
+ mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight3_t, weight3_s, group_size
+ )
+
+ # [MEM TEST]
+ del out_g, out_g_t, weight3_t
+
+ # Element-wise Multiplication, 1 means gate, 2 means up
+ mul_g1, (mul_g2, mul_gs2, mul_g2_t) = fp8_mul_backward(
+ silu_x, silu_s, up_x, up_s, fc3_g, group_size, fwobits["babit"], output_quantized_transpose=True
+ )
+
+ # Silu activation
+ silu_g, silu_gs, silu_g_t = fp8_silu_backward(
+ gate_x, gate_s, mul_g1, group_size, fwobits["babit"], output_quantized_transpose=True
+ )
+
+ # Linear Layer of Up and Gate Projection
+ if qargs.weight_memory_efficient:
+ weight1_t, weight1_s = fp8_division_transpose(
+ weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True
+ )
+ weight2_t, weight2_s = fp8_division_transpose(
+ weight2_t, group_size, fwobits["fwbit"], weight2_s, only_transposed=True
+ )
+
+ # Gate Proj
+ fc1_g, weight1_grad = fp8_linear_backward(
+ ln_x_t, ln_s, silu_g, silu_gs, silu_g_t, weight1_t, weight1_s, group_size
+ )
+ fc2_g, weight2_grad = fp8_linear_backward(
+ ln_x_t, ln_s, mul_g2, mul_gs2, mul_g2_t, weight2_t, weight2_s, group_size
+ )
+
+ fc_g = fc1_g + fc2_g
+
+ # layerNorm
+ in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc_g, rms_weight, rstd, group_size, num_warps)
+
+ # Add the gradient together
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
+ fp_grad, in_g, group_size, fwobits["babit"], stochastic=False
+ )
+
+ in_g = in_g.view(torch.float8_e4m3fn)
+
+ return (
+ re_g,
+ in_g,
+ in_sg_g16,
+ weight1_grad,
+ None,
+ None,
+ None,
+ weight2_grad,
+ None,
+ None,
+ None,
+ weight3_grad,
+ None,
+ None,
+ None,
+ rms_weight_grad,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class LlamaAttentionWithoutLinear(nn.Module):
+ """
+ Remove the Q/K/V/O projection layer in LlamaAttention module and only calculate the attention logic.
+ The Q/K/V Projection is moved to BeforeAttention Module, and the O Projection is moved to AfterAttention Module.
+ """
+
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+
+ # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
+ self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
+
+ def forward(
+ self,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class LlamaFlashAttention2WithoutLinear(LlamaAttentionWithoutLinear):
+ """
+ Remove the Q/K/V/O projection layer in LlamaFlashAttention2 module and only calculate the attention logic.
+ The Q/K/V Projection is moved to BeforeAttention Module, and the O Projection is moved to AfterAttention Module.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = query_states.size()
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference so it is recommended not to cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class LlamaSdpaAttentionWithoutLinear(LlamaAttentionWithoutLinear):
+ """
+ Remove the Q/K/V/O projection layer in LlamaSdpaAttention module and only calculate the attention logic.
+ The Q/K/V Projection is moved to BeforeAttention Module, and the O Projection is moved to AfterAttention Module.
+ """
+
+ # Adapted from LlamaAttention.forward
+ def forward(
+ self,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ query_states=query_states,
+ key_states=key_states,
+ value_states=value_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ return attn_output, None, past_key_value
+
+
+COAT_LLAMA_ATTENTION_CLASSES = {
+ "eager": LlamaAttentionWithoutLinear,
+ "flash_attention_2": LlamaFlashAttention2WithoutLinear,
+ "sdpa": LlamaSdpaAttentionWithoutLinear,
+}
+
+
+class CoatLlamaDecoderLayer(nn.Module):
+ def __init__(self, config: CoatLlamaConfig, layer_idx: int):
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = COAT_LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+ self.qargs = QuantizationConfig(**config.coat_fp8_args)
+ self.BeforeAttention = CoatLlamaBeforeAttentionResidual(config, self.qargs, layer_idx)
+ self.AfterAttention = CoatLlamaAfterAttentionResidual(config, self.qargs, layer_idx)
+ self.MLPResidual = CoatLlamaMLPResidual(config, self.qargs, layer_idx, self.hidden_size)
+
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ quant_hidden_states: torch.Tensor,
+ scale_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): BF16 input to the layer of shape `(batch, seq_len, embed_dim)`
+ quant_hidden_states (`torch.float8_e4m3fn`): FP8 input to the layer of shape `(batch, seq_len, embed_dim)`
+ scale_hidden_states (`torch.bfloat16`): BF16 scaling factor to the layer of shape `(batch, seq_len, embed_dim // group_size)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+
+ # Coat: The residual, LayerNorm, and the Q/K/V Projection Linear Layer
+ residual, query_states, key_states, value_states = self.BeforeAttention(
+ hidden_states, quant_hidden_states, scale_hidden_states, self.input_layernorm.weight
+ )
+
+ # Self Attention without any linear layer
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ query_states=query_states,
+ key_states=key_states,
+ value_states=value_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ # Coat: The Output Projection Linear Layer and Residual
+ hidden_states, quant_hidden_states, scale_hidden_states = self.AfterAttention(residual, hidden_states)
+
+ # Residual Connection, LayerNorm, and the whole MLP module
+ hidden_states, quant_hidden_states, scale_hidden_states = self.MLPResidual(
+ hidden_states, quant_hidden_states, scale_hidden_states, self.post_attention_layernorm.weight
+ )
+
+ outputs = ((hidden_states, quant_hidden_states, scale_hidden_states),)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class CoatLlamaPreTrainedModel(PreTrainedModel):
+ config_class = CoatLlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class CoatLlamaModel(CoatLlamaPreTrainedModel):
+ """
+ Coat Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CoatLlamaDecoderLayer`]
+
+ Args:
+ config: CoatLlamaConfig
+ """
+
+ def __init__(self, config: CoatLlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [CoatLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Quantize
+ self.qargs = QuantizationConfig(**config.coat_fp8_args)
+ self.quantize_input_before_block = Coat_quantize_bgn(self.qargs)
+ self.quantize_output_after_block = Coat_quantize_end(self.qargs)
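+ # Coat_quantize_bgn turns the embedding output into the (BF16, FP8, scale) triplet that the fused
+ # decoder layers exchange; Coat_quantize_end converts it back to a plain hidden state after the last layer.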
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ # Prepare the input for Coat decoderlayer
+ hidden_states, quant_hidden_states, scale_hidden_states = self.quantize_input_before_block(hidden_states)
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ quant_hidden_states,
+ scale_hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ quant_hidden_states,
+ scale_hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ hidden_states, quant_hidden_states, scale_hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ # Summarize the output of the Decoder Layer
+ hidden_states = self.quantize_output_after_block(hidden_states, quant_hidden_states, scale_hidden_states)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ _update_causal_mask = LlamaModel._update_causal_mask
+
+
+class CoatLlamaForCausalLM(CoatLlamaPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = CoatLlamaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ forward = LlamaForCausalLM.forward
+
+ prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation
+
+
+# TODO
+# class LlamaForSequenceClassification(LlamaPreTrainedModel):
+
+# class LlamaForQuestionAnswering(LlamaPreTrainedModel):
+
+# class LlamaForTokenClassification(LlamaPreTrainedModel):
+
+
+def make_state_dict_compatible(state_dict: dict[str, torch.Tensor]):
+ compatible_state_dict = {}
+
+ for key, value in state_dict.items():
+ if fnmatch(key, "*self_attn.q_proj*"):
+ new_key = key.replace("self_attn.q_proj", "BeforeAttention.q_proj")
+ elif fnmatch(key, "*self_attn.k_proj*"):
+ new_key = key.replace("self_attn.k_proj", "BeforeAttention.k_proj")
+ elif fnmatch(key, "*self_attn.v_proj*"):
+ new_key = key.replace("self_attn.v_proj", "BeforeAttention.v_proj")
+ elif fnmatch(key, "*self_attn.o_proj*"):
+ new_key = key.replace("self_attn.o_proj", "AfterAttention.o_proj")
+
+ elif fnmatch(key, "*mlp.gate_proj*"):
+ new_key = key.replace("mlp.gate_proj", "MLPResidual.gate_proj")
+ elif fnmatch(key, "*mlp.up_proj*"):
+ new_key = key.replace("mlp.up_proj", "MLPResidual.up_proj")
+ elif fnmatch(key, "*mlp.down_proj*"):
+ new_key = key.replace("mlp.down_proj", "MLPResidual.down_proj")
+
+ else:
+ new_key = key
+
+ compatible_state_dict[new_key] = value
+
+ return compatible_state_dict
+
+
+AutoConfig.register("fp8_llama", CoatLlamaConfig)
+AutoModel.register(CoatLlamaConfig, CoatLlamaModel)
+AutoModelForCausalLM.register(CoatLlamaConfig, CoatLlamaForCausalLM)
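+
+# Usage sketch (an illustration, not part of the library): once the auto classes are
+# registered above, a checkpoint converted with `make_state_dict_compatible` and saved with
+# a config whose model_type is "fp8_llama" can be reloaded through the standard auto classes.
+# The path below is hypothetical.
+#
+#   from transformers import AutoModelForCausalLM
+#   model = AutoModelForCausalLM.from_pretrained("/path/to/converted/coat-llama-fp8")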
diff --git a/llava/model/coat/activation/models/coat_llama_convert_from_hf.py b/llava/model/coat/activation/models/coat_llama_convert_from_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..c02da9e1b4211f8becaa2a54dc30a46fcf88185b
--- /dev/null
+++ b/llava/model/coat/activation/models/coat_llama_convert_from_hf.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import argparse
+import os
+from dataclasses import asdict, dataclass, field
+from typing import Optional
+
+import torch
+import transformers
+from coat.activation.models._fp8_quantization_config import QuantizationConfig
+from coat.activation.models.coat_llama import CoatLlamaConfig, CoatLlamaForCausalLM, make_state_dict_compatible
+from transformers import AutoConfig, AutoModelForCausalLM
+
+
+@dataclass
+class ConvertArguments:
+ model_name: str = field(metadata={"help": "The model name or path to download the LLaMA model"})
+ save_path: str = field(metadata={"help": "The path where the converted model weights will be saved"})
+ cache_dir: Optional[str] = field(default=None, metadata={"help": "Directory to cache the model"})
+
+
+def download_and_convert_llama(convert_args: ConvertArguments, quantization_args: QuantizationConfig):
+ """
+ Downloads a LLaMA model, converts its weights using `make_state_dict_compatible`,
+ and saves the converted model.
+
+ Args:
+ convert_args (ConvertArguments): The model name or path, the save path, and an optional cache directory.
+ quantization_args (QuantizationConfig): FP8 quantization settings stored in the converted model's config.
+
+ Returns:
+ None
+ """
+ model_name = convert_args.model_name
+ save_path = convert_args.save_path
+ cache_dir = convert_args.cache_dir
+
+ # Step 1: Download the original LLaMA model
+ model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
+
+ # Step 2: Initialize the model configuration for FP8 or other custom config
+ config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
+
+ # Step 3: Apply make_state_dict_compatible to convert weights
+ compatible_state_dict = make_state_dict_compatible(model.state_dict())
+
+ # Step 4: Create a new model instance with compatible configuration
+ fp8_config = CoatLlamaConfig(**config.to_dict())
+ fp8_config.coat_fp8_args = asdict(quantization_args)
+
+ converted_model = AutoModelForCausalLM.from_config(fp8_config)
+ converted_model.load_state_dict(compatible_state_dict)
+
+ # Step 5: Save the converted model and configuration using save_pretrained
+ os.makedirs(save_path, exist_ok=True)
+ converted_model.save_pretrained(save_path)
+ print(f"Converted model saved at {save_path}")
+
+
+if __name__ == "__main__":
+ # Parse command-line arguments
+ parser = transformers.HfArgumentParser((ConvertArguments, QuantizationConfig)) # NOTE: FP8
+ convert_args, quantization_args = parser.parse_args_into_dataclasses()
+
+ # Call the function with parsed arguments
+ download_and_convert_llama(convert_args, quantization_args)
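+
+# Example invocation (a sketch; the model id and save path below are placeholders, and any
+# FP8-specific flags defined on QuantizationConfig can be appended in the same way):
+#
+#   python coat_llama_convert_from_hf.py \
+#       --model_name meta-llama/Llama-2-7b-hf \
+#       --save_path ./converted/coat-llama-2-7b-fp8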
diff --git a/llava/model/coat/activation/models/coat_olmo.py b/llava/model/coat/activation/models/coat_olmo.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5d2cdedd73c56a68396aebdab1a2762876abc53
--- /dev/null
+++ b/llava/model/coat/activation/models/coat_olmo.py
@@ -0,0 +1,1942 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Adapted from
+[MosaiclML](https://github.com/mosaicml/examples.git) and
+[minGPT](https://github.com/karpathy/minGPT.git)
+"""
+
+from __future__ import annotations
+
+import logging
+import math
+import sys
+from abc import abstractmethod
+from collections import defaultdict
+from functools import partial
+from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Sequence, Set, Tuple, cast
+
+import torch
+import torch.backends.cuda
+import torch.nn as nn
+import torch.nn.functional as F
+from olmo.aliases import PathOrStr
+from olmo.beam_search import BeamSearch, Constraint, FinalSequenceScorer, Sampler
+from olmo.config import (
+ ActivationCheckpointingStrategy,
+ ActivationType,
+ BlockType,
+ CheckpointType,
+ FSDPWrapStrategy,
+ InitFnType,
+ LayerNormType,
+ ModelConfig,
+ QuantActivationConfig,
+ ShardedCheckpointerType,
+ TrainConfig,
+)
+from olmo.exceptions import OLMoConfigurationError
+from olmo.initialization import init_normal
+from olmo.model import (
+ Activation,
+ BufferCache,
+ Dropout,
+ LayerNorm,
+ LayerNormBase,
+ OLMo,
+ OLMoBlock,
+ OLMoBlockGroup,
+ OLMoGenerateOutput,
+ OLMoOutput,
+ RMSLayerNorm,
+ RotaryEmbedding,
+ _non_meta_init_device,
+ activation_checkpoint_function,
+ alibi_attention_bias,
+ causal_attention_bias,
+ get_causal_attention_bias,
+ should_checkpoint_block,
+)
+from olmo.torch_util import ensure_finite_, get_cumulative_document_lengths
+from torch import einsum
+
+from ..real_quantization import (
+ Coat_quantize_bgn,
+ Coat_quantize_end,
+ fp8_add_Ifp_Ifp_Ofp_Og16,
+ fp8_add_Ifp_Ifp_Ofp_Opt,
+ fp8_division,
+ fp8_division_transpose,
+ fp8_gelu_backward,
+ fp8_gelu_forward,
+ fp8_layernorm_noparam_backward,
+ fp8_layernorm_noparam_forward,
+ fp8_linear_backward,
+ fp8_linear_forward,
+ fp8_mul_backward,
+ fp8_mul_forward,
+ fp8_quantize,
+ fp8_quantize_pertensor,
+ fp8_quantize_pertensor_transpose,
+ fp8_rmsnorm_backward,
+ fp8_rmsnorm_forward,
+ fp8_silu_backward,
+ fp8_silu_forward,
+ fp8_transpose,
+)
+from ._fp8_weightcache import FP8CacheWeightModule
+from ._fp8manager import FP8Manager
+
+if sys.version_info.minor > 8:
+ from collections.abc import MutableMapping
+elif sys.version_info.minor == 8:
+ from typing import MutableMapping
+else:
+ raise SystemExit("This script supports Python 3.8 or higher")
+
+__all__ = [
+ "LayerNormBase",
+ "LayerNorm",
+ "RMSLayerNorm",
+ "RotaryEmbedding",
+ "Activation",
+ "GELU",
+ "ReLU",
+ "SwiGLU",
+ "OLMoBlock",
+ "OLMoSequentialBlock",
+ "OLMo",
+ "OLMoOutput",
+ "OLMoGenerateOutput",
+]
+
+
+log = logging.getLogger(__name__)
+
+
+class CoatOLMoBeforeAttentionResidual(FP8CacheWeightModule):
+ """
+ This is the pre-attention part of a transformer block: (1) the residual stream, (2) LayerNorm / RMSNorm, and (3) one linear layer (the fused QKV projection)
+ """
+
+ def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, layer_id, fused_dims: tuple):
+ super().__init__(config, qargs, layer_id)
+
+ self.qargs = qargs
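+ # fwobits collects the number formats used by this module: fabit/fwbit/fobit are the
+ # forward activation / weight / output formats, and babit/bwbit/bobit are their
+ # backward-pass counterparts (e.g. E4M3 activations in the forward pass, E5M2 gradients
+ # in the backward pass).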
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+ self.ln_normalized_shape = config.d_model
+ self.att_proj = nn.Linear(config.d_model, sum(fused_dims), bias=config.include_bias, device=config.init_device)
+
+ self.attn_norm = LayerNorm.build(config)
+
+ def forward(self, re_x, x, s):
+ if self.training:
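+ # In the weight-memory-efficient mode, prepare_weight returns only the per-group weight
+ # scales; the FP8 weight (and its transpose) are re-created inside the autograd Function,
+ # trading a little extra compute for a smaller memory footprint.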
+ if self.qargs.weight_memory_efficient:
+ # Prepare
+ with torch.no_grad():
+ weight1_s = self.prepare_weight(self.att_proj.weight, "att_proj", FP8Manager.is_first_microbatch)
+ return _CoatOLMoBeforeAttentionResidual.apply(
+ re_x,
+ x,
+ s,
+ self.att_proj.weight,
+ None,
+ None,
+ weight1_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # Prepare
+ with torch.no_grad():
+ weight1, weight1_t, weight1_s = self.prepare_weight(
+ self.att_proj.weight, "att_proj", FP8Manager.is_first_microbatch
+ )
+ return _CoatOLMoBeforeAttentionResidual.apply(
+ re_x,
+ x,
+ s,
+ self.att_proj.weight,
+ weight1,
+ weight1_t,
+ weight1_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ return re_x, self.att_proj(self.attn_norm(re_x))
+
+
+class _CoatOLMoBeforeAttentionResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ re_x,
+ in_x,
+ in_s,
+ weight1_origin,
+ weight1,
+ weight1_t,
+ weight1_s,
+ group_size,
+ fwobits,
+ layer_id,
+ config,
+ qargs,
+ eps=1e-5,
+ ):
+ # for autograd
+ if fwobits["fabit"] == "E4M3":
+ # in_x = in_x.to(torch.float8_e4m3fn)
+ in_x = in_x.view(torch.float8_e4m3fn)
+ else:
+ raise ValueError("fabit should be E4M3")
+
+ # LayerNorm
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_layernorm_noparam_forward(
+ in_x, in_s, group_size, eps, transpose_output_2d=True
+ )
+
+ # Linear Layer QKV Projection
+ if qargs.weight_memory_efficient:
+ assert weight1 is None # memory efficient
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
+ fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size)
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s)
+ if qargs.weight_memory_efficient:
+ assert weight1_t is None
+ ctx.weight = weight1_origin, weight1_s
+ else:
+ ctx.weight = weight1_t, weight1_s
+ ctx.group_size = group_size
+ ctx.ln_utils = ln_utils
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ return re_x, fc1_x
+
+ @staticmethod
+ def backward(ctx, fp_grad, flash_g):
+ in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors
+ weight1_t, weight1_s = ctx.weight
+ group_size = ctx.group_size
+ mean, rstd, num_warps = ctx.ln_utils
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # ==================== Begin backward ====================
+ # Quantize the RoPE and FlashAttention output. grad_input and grad_weight require different data layouts.
+ flash_g, flash_gs, flash_g_t = fp8_quantize_pertensor_transpose(
+ flash_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+
+ # Linear Layer QKV Projection
+ if qargs.weight_memory_efficient:
+ weight1_t, weight1_s = fp8_division_transpose(
+ weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True
+ )
+ fc1_g, att_proj_wg = fp8_linear_backward(
+ ln_x_t, ln_s, flash_g, flash_gs, flash_g_t, weight1_t, weight1_s, group_size
+ )
+
+ # LayerNorm
+ in_g = fp8_layernorm_noparam_backward(in_x, in_s, fc1_g, group_size, mean, rstd, num_warps)
+
+ # Add the gradients together and prepare the input of the next layer.
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
+ fp_grad, in_g, group_size, fwobits["babit"], stochastic=False
+ )
+
+ # For autograd, the forward's data type should match the backward tensor's; this view does not change the underlying bit pattern.
+ in_g = in_g.view(torch.float8_e4m3fn)
+
+ # Although the next operator is a linear layer in the MLPResidual module, we return in_sg_g16 so the gradient sizes match the forward inputs; otherwise autograd would reject them.
+ return re_g, in_g, in_sg_g16, att_proj_wg, None, None, None, None, None, None, None, None, None
+
+
+class CoatOLMoAfterAttentionResidual(FP8CacheWeightModule):
+ """
+ This is the post-attention part of a transformer block: (1) the residual stream and (2) one linear layer (the attention output projection)
+ """
+
+ def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, layer_id):
+ super().__init__(config, qargs, layer_id)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+ self.attn_out = nn.Linear(config.d_model, config.d_model, bias=config.include_bias, device=config.init_device)
+
+ def forward(self, re_x, in_x):
+ if self.training:
+ if self.qargs.weight_memory_efficient:
+ # prepare for the weight
+ with torch.no_grad():
+ weight2_s = self.prepare_weight(self.attn_out.weight, "attn_out", FP8Manager.is_first_microbatch)
+
+ return _CoatOLMoAfterAttentionResidual.apply(
+ re_x,
+ in_x,
+ self.attn_out.weight,
+ None,
+ None,
+ weight2_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # prepare for the weight
+ with torch.no_grad():
+ weight2, weight2_t, weight2_s = self.prepare_weight(
+ self.attn_out.weight, "attn_out", FP8Manager.is_first_microbatch
+ )
+
+ return _CoatOLMoAfterAttentionResidual.apply(
+ re_x,
+ in_x,
+ self.attn_out.weight,
+ weight2,
+ weight2_t,
+ weight2_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ return re_x + self.attn_out(in_x), None, None
+
+
+class _CoatOLMoAfterAttentionResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx, re_x, flash_x, weight2_origin, weight2, weight2_t, weight2_s, group_size, fwobits, layer_id, config, qargs
+ ):
+ # Quantize the FlashAttention Output
+ flash_qx, flash_s, _ = fp8_quantize_pertensor(
+ flash_x, group_size, fwobits["fabit"]
+ ) # Modified to make it memory efficient
+
+ # Attention Projection Linear Layer
+ if qargs.weight_memory_efficient:
+ assert weight2 is None # memory efficient
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
+ fc2_x = fp8_linear_forward(flash_qx, flash_s, weight2, weight2_s, False, group_size)
+
+ # Add the activations together
+ fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc2_x, flash_qx.dtype, group_size)
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(flash_x, flash_s)
+ if qargs.weight_memory_efficient:
+ assert weight2_t is None
+ ctx.weight = weight2_origin, weight2_s
+ else:
+ ctx.weight = weight2_t, weight2_s
+ ctx.group_size = group_size
+ ctx.fwobits = fwobits
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ # For autograd
+ out_x = out_x.view(torch.float8_e4m3fn)
+
+ return fp_x, out_x, out_s
+
+ @staticmethod
+ def backward(ctx, fp_grad, out_g, out_gs):
+ flash_x, flash_s = ctx.saved_tensors
+ weight2_t, weight2_s = ctx.weight
+ group_size = ctx.group_size
+ fwobits = ctx.fwobits
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # for autograd
+ if fwobits["babit"] == "E5M2":
+ # out_g = out_g.to(torch.float8_e5m2)
+ out_g = out_g.view(torch.float8_e5m2)
+ else:
+ raise ValueError("babit should be E5M2")
+ out_gs_max = out_gs.max()
+
+ # ==================== Begin backward ====================
+ # Output Projection
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
+
+ # We do not save an extra copy of flash_x, to reduce memory usage
+ flash_x_t, flash_s = fp8_division_transpose(
+ flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True
+ )
+
+ if qargs.weight_memory_efficient:
+ weight2_t, weight2_s = fp8_division_transpose(
+ weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True
+ )
+ fc2_g, attn_out_wg = fp8_linear_backward(
+ flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight2_t, weight2_s, group_size
+ )
+
+ return fp_grad, fc2_g, attn_out_wg, None, None, None, None, None, None, None, None
+
+
+class CoatOLMoMLPResidual(FP8CacheWeightModule):
+ """
+ This is a typical transformer MLP module that contains (1) the residual stream, (2) LayerNorm / RMSNorm, (3) two or three
+ linear layers, and (4) a GELU / SiLU activation
+ """
+
+ def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, layer_id, hidden_size: int):
+ super().__init__(config, qargs, layer_id)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+ self.ln_normalized_shape = config.d_model
+ self.act_output_multiplier = 0.5 if config.activation_type == ActivationType.swiglu else 1
+ self.ff_proj = nn.Linear(config.d_model, hidden_size, bias=config.include_bias, device=config.init_device)
+ self.ff_out = nn.Linear(
+ int(self.act_output_multiplier * hidden_size),
+ config.d_model,
+ bias=config.include_bias,
+ device=config.init_device,
+ )
+ self.training = True
+
+ # below is only used when training = False
+ self.ff_norm = LayerNorm.build(config)
+ self.act = Activation.build(config)
+ assert (self.act.output_multiplier * hidden_size) % 1 == 0
+
+ def forward(self, re_x, x, s):
+ if self.training:
+ if self.qargs.weight_memory_efficient: # prepare for the weight
+ with torch.no_grad():
+ weight1_s = self.prepare_weight(self.ff_proj.weight, "ff_proj", FP8Manager.is_first_microbatch)
+ weight2_s = self.prepare_weight(self.ff_out.weight, "ff_out", FP8Manager.is_first_microbatch)
+
+ return _CoatOLMoMLPResidual.apply(
+ re_x,
+ x,
+ s,
+ self.ff_proj.weight,
+ None,
+ None,
+ weight1_s,
+ self.ff_out.weight,
+ None,
+ None,
+ weight2_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # prepare for the weight
+ with torch.no_grad():
+ weight1, weight1_t, weight1_s = self.prepare_weight(
+ self.ff_proj.weight, "ff_proj", FP8Manager.is_first_microbatch
+ )
+ weight2, weight2_t, weight2_s = self.prepare_weight(
+ self.ff_out.weight, "ff_out", FP8Manager.is_first_microbatch
+ )
+
+ return _CoatOLMoMLPResidual.apply(
+ re_x,
+ x,
+ s,
+ self.ff_proj.weight,
+ weight1,
+ weight1_t,
+ weight1_s,
+ self.ff_out.weight,
+ weight2,
+ weight2_t,
+ weight2_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ og_x = re_x
+ re_x = self.ff_norm(re_x)
+ re_x = self.ff_proj(re_x)
+ re_x = self.act(re_x)
+ re_x = self.ff_out(re_x)
+ re_x = og_x + re_x
+ return re_x, None, None
+
+
+class _CoatOLMoMLPResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ re_x,
+ in_x,
+ in_s,
+ weight1_origin,
+ weight1,
+ weight1_t,
+ weight1_s,
+ weight2_origin,
+ weight2,
+ weight2_t,
+ weight2_s,
+ group_size,
+ fwobits,
+ layer_id,
+ config,
+ qargs,
+ eps=1e-5,
+ ):
+ # For autograd
+ if fwobits["fabit"] == "E4M3":
+ # in_x = in_x.to(torch.float8_e4m3fn)
+ in_x = in_x.view(torch.float8_e4m3fn)
+ else:
+ raise ValueError("fabit should be E4M3")
+
+ # LayerNorm
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_layernorm_noparam_forward(
+ in_x, in_s, group_size, eps, transpose_output_2d=True
+ )
+
+ # Linear Layer of Up Projection and Gate Projection. They are fused as one linear layer.
+ if qargs.weight_memory_efficient:
+ assert weight1 is None # memory efficient
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
+ fc1_x, fc1_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size)
+
+ # NOTE: be careful of the order: the fused ff_proj output is split into [up | gate] along the last dimension
+ up_x, gate_x = fc1_x.chunk(2, dim=-1)
+ up_s, gate_s = fc1_s.chunk(2, dim=-1)
+
+ # silu Activation
+ silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size)
+
+ # Element-wise Multiplication
+ mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True)
+
+ # Output Projection
+ if weight2 is None: # memory efficient
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
+ fc2_x = fp8_linear_forward(mul_x, mul_s, weight2, weight2_s, False, group_size)
+
+ # Add the activation together
+ fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc2_x, mul_x.dtype, group_size)
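+ # Taken together, the steps above compute the usual SwiGLU MLP residual, just in FP8:
+ # out = re_x + ff_out(silu(gate) * up), where [up | gate] = ff_proj(layernorm(re_x)).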
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s)
+
+ if (
+ qargs.weight_memory_efficient
+ ): # Weight_1/2_origin will not be saved twice, so it will be more memory efficient.
+ assert weight1_t is None
+ ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s)
+ else: # Weight1/2_t differs from the original weight, so saving it consumes additional memory.
+ ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s)
+
+ ctx.group_size = group_size
+ ctx.ln_utils = ln_utils
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ out_x = out_x.view(torch.float8_e4m3fn)
+
+ return fp_x, out_x, out_s
+
+ @staticmethod
+ def backward(ctx, fp_grad, out_g, out_gs):
+ in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s = ctx.saved_tensors
+
+ (weight1_t, weight1_s, weight2_t, weight2_s) = ctx.weight
+ group_size = ctx.group_size
+ mean, rstd, num_warps = ctx.ln_utils
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # For autograd
+ if fwobits["babit"] == "E5M2":
+ # out_g = out_g.to(torch.float8_e5m2)
+ out_g = out_g.view(torch.float8_e5m2)
+ else:
+ raise ValueError("babit should be E5M2")
+ out_gs_max = out_gs.max()
+
+ # ==================== Begin backward ====================
+ # Output Projection
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
+
+ if qargs.weight_memory_efficient:
+ weight2_t, weight2_s = fp8_division_transpose(
+ weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True
+ )
+ fc2_g, weight2_grad = fp8_linear_backward(
+ mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight2_t, weight2_s, group_size
+ )
+
+ # Free large intermediates early to reduce peak memory usage
+ del out_g, out_g_t, weight2_t
+
+ # Element-wise Multiplication, 1 means gate, 2 means up
+ mul_g1, (mul_g2, mul_gs2) = fp8_mul_backward(silu_x, silu_s, up_x, up_s, fc2_g, group_size, fwobits["babit"])
+
+ # Silu activation
+ silu_g, silu_gs = fp8_silu_backward(gate_x, gate_s, mul_g1, group_size, fwobits["babit"])
+
+ # Prepare the input of the linear layer. NOTE: be careful of the order; it must mirror the [up | gate] split used in the forward pass
+ gateup_g = torch.cat([mul_g2, silu_g], dim=-1)
+ gateup_gs = torch.cat([mul_gs2, silu_gs])
+ gateup_gs = torch.max(gateup_gs)
+
+ gateup_g, gateup_gs, gateup_g_t = fp8_division_transpose(
+ gateup_g, group_size, fwobits["babit"], gateup_gs, stochastic=False
+ )
+
+ # Linear Layer of Up and Gate Projection
+ if qargs.weight_memory_efficient:
+ weight1_t, weight1_s = fp8_division_transpose(
+ weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True
+ )
+ fc1_g, weight1_grad = fp8_linear_backward(
+ ln_x_t, ln_s, gateup_g, gateup_gs, gateup_g_t, weight1_t, weight1_s, group_size
+ )
+
+ # layerNorm
+ in_g = fp8_layernorm_noparam_backward(in_x, in_s, fc1_g, group_size, mean, rstd, num_warps)
+
+ # Add the gradients together
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
+ fp_grad, in_g, group_size, fwobits["babit"], stochastic=False
+ )
+
+ in_g = in_g.view(torch.float8_e4m3fn)
+
+ return (
+ re_g,
+ in_g,
+ in_sg_g16,
+ weight1_grad,
+ None,
+ None,
+ None,
+ weight2_grad,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class CoatOLMoBlock(nn.Module):
+ """
+ A base class for transformer block implementations.
+ """
+
+ def __init__(self, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache):
+ super().__init__()
+ self.layer_id = layer_id
+ self.config = config
+ self.qargs = qargs
+ self.hidden_size = (
+ config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
+ )
+ self.__cache = cache
+ assert config.d_model % config.n_heads == 0
+
+ self._activation_checkpoint_fn: Callable | None = None
+
+ # Dropout.
+ self.dropout = Dropout(config.residual_dropout)
+
+ # Layer norms.
+ self.k_norm: LayerNormBase | None = None
+ self.q_norm: LayerNormBase | None = None
+ if config.attention_layer_norm:
+ assert config.effective_n_kv_heads is not None
+ self.k_norm = LayerNormBase.build(
+ config,
+ size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
+ elementwise_affine=config.attention_layer_norm_with_affine,
+ )
+ self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
+
+ # Make sure QKV clip coefficient is positive, otherwise it's not well-defined.
+ if config.clip_qkv is not None:
+ assert config.clip_qkv > 0
+
+ # Activation function.
+ self.act = Activation.build(config)
+ assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
+
+ if not self.qargs.use_quantize_model:
+ # Attention output projection.
+ self.attn_out = nn.Linear(
+ config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
+ )
+
+ # Feed-forward output projection.
+ self.ff_out = nn.Linear(
+ int(self.act.output_multiplier * self.hidden_size),
+ config.d_model,
+ bias=config.include_bias,
+ device=config.init_device,
+ )
+ self.ff_out._is_residual = True # type: ignore
+
+ # Rotary embeddings.
+ if self.config.rope:
+ self.rotary_emb = RotaryEmbedding(config, self.__cache)
+
+ self.flash_attn_func = None
+ self.flash_attn_varlen_func = None
+ if config.flash_attention:
+ try:
+ from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
+
+ self.flash_attn_func = flash_attn_func
+ self.flash_attn_varlen_func = flash_attn_varlen_func
+ except ModuleNotFoundError:
+ pass
+
+ def reset_parameters(self):
+ if self.k_norm is not None:
+ self.k_norm.reset_parameters()
+ if self.q_norm is not None:
+ self.q_norm.reset_parameters()
+
+ if not self.qargs.use_quantize_model:
+ if self.config.init_fn == InitFnType.normal:
+ attn_out_std = ff_out_std = self.config.init_std
+ cutoff_factor = self.config.init_cutoff_factor
+
+ elif self.config.init_fn == InitFnType.mitchell:
+ attn_out_std = 1 / (math.sqrt(2 * self.config.d_model * (self.layer_id + 1)))
+ ff_out_std = 1 / (math.sqrt(2 * self.ff_out.in_features * (self.layer_id + 1)))
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
+
+ elif self.config.init_fn == InitFnType.full_megatron:
+ attn_out_std = ff_out_std = self.config.init_std / math.sqrt(2.0 * self.config.n_layers)
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
+
+ else:
+ raise NotImplementedError(self.config.init_fn)
+
+ init_normal(self.attn_out, std=attn_out_std, init_cutoff_factor=cutoff_factor)
+ init_normal(self.ff_out, std=ff_out_std, init_cutoff_factor=cutoff_factor)
+
+ def set_activation_checkpointing(
+ self, strategy: ActivationCheckpointingStrategy | None, checkpoint_func: Callable | None = None
+ ):
+ if strategy == ActivationCheckpointingStrategy.fine_grained:
+ self._activation_checkpoint_fn = checkpoint_func or activation_checkpoint_function(self.config)
+ else:
+ self._activation_checkpoint_fn = None
+
+ @classmethod
+ def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
+ target_dtype = input_dtype
+ # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
+ # `is_autocast_cpu_enabled()` for CPU autocast.
+ # See https://github.com/pytorch/pytorch/issues/110966.
+ if bias.device.type == "cuda" and torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
+ target_dtype = torch.get_autocast_cpu_dtype()
+ if bias.dtype != target_dtype:
+ bias = bias.to(target_dtype)
+ ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
+ return bias
+
+ def _scaled_dot_product_attention(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ attn_mask: torch.Tensor | None = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ max_doc_len: int | None = None,
+ cu_doc_lens: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ """
+ Computes scaled dot product attention on query, key and value tensors, using an optional
+ attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
+ """
+ if max_doc_len is not None and cu_doc_lens is not None:
+ assert self.flash_attn_varlen_func is not None, "flash-attn is required for document masking"
+ assert attn_mask is None, "attn-mask is currently not supported with document masking"
+ B, T, D = q.size(0), q.size(2), q.size(3)
+ r = self.flash_attn_varlen_func(
+ q.transpose(1, 2).view(B * T, -1, D),
+ k.transpose(1, 2).view(B * T, -1, D),
+ v.transpose(1, 2).view(B * T, -1, D),
+ cu_doc_lens,
+ cu_doc_lens,
+ max_doc_len,
+ max_doc_len,
+ dropout_p=dropout_p,
+ causal=is_causal,
+ )
+ return r.view(B, T, -1, D).transpose(1, 2)
+ elif self.flash_attn_func is not None and attn_mask is None:
+ r = self.flash_attn_func(
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=is_causal
+ )
+ return r.transpose(1, 2)
+ else:
+ # torch's sdpa doesn't support GQA, so we're doing this
+ assert k.size(1) == v.size(1)
+ num_kv_heads = k.size(1)
+ num_q_heads = q.size(1)
+ if num_q_heads != num_kv_heads:
+ assert num_q_heads % num_kv_heads == 0
+ k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
+ v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
+
+ return F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ )
+
+ def attention(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ attention_bias: torch.Tensor | None = None,
+ layer_past: tuple[torch.Tensor, torch.Tensor] | None = None,
+ use_cache: bool = False,
+ max_doc_len: int | None = None,
+ cu_doc_lens: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
+ B, T, C = q.size() # batch size, sequence length, d_model
+ dtype = k.dtype
+
+ # Optionally apply layer norm to keys and queries.
+ if self.q_norm is not None and self.k_norm is not None:
+ q = self.q_norm(q).to(dtype=dtype)
+ k = self.k_norm(k).to(dtype=dtype)
+
+ # Move head forward to be next to the batch dim.
+ # shape: (B, nh, T, hs)
+ q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
+ # shape: (B, n_kv_h, T, hs)
+ k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
+ # shape: (B, n_kv_h, T, hs)
+ v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
+
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ k = torch.cat((past_key, k), dim=-2)
+ v = torch.cat((past_value, v), dim=-2)
+
+ present = (k, v) if use_cache else None
+ query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
+
+ if self.config.rope:
+ # Apply rotary embeddings.
+ q, k = self.rotary_emb(q, k)
+
+ if attention_bias is not None:
+ # Resize and cast attention bias.
+ # The current dtype of the attention bias might not match the dtype that the SDP attn function will
+ # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
+ # as down-casting the attention bias to the autocast precision will result in -infs, which will
+ # cause the SDP attn function to produce NaNs.
+ attention_bias = self._cast_attn_bias(attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype)
+
+ # Get the attention scores.
+ # shape: (B, nh, T, hs)
+ att = self._scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=attention_bias,
+ dropout_p=0.0 if not self.training else self.config.attention_dropout,
+ is_causal=attention_bias is None,
+ max_doc_len=max_doc_len,
+ cu_doc_lens=cu_doc_lens,
+ )
+
+ # Re-assemble all head outputs side-by-side.
+ att = att.transpose(1, 2).contiguous().view(B, T, C)
+
+ # Apply output projection. NOTE: We move the attn output outside of this attention function
+ return att, present
+
+ @abstractmethod
+ def forward(
+ self,
+ x: torch.Tensor,
+ attention_bias: torch.FloatTensor | None = None,
+ layer_past: tuple[torch.Tensor, torch.Tensor] | None = None,
+ use_cache: bool = False,
+ max_doc_len: int | None = None,
+ cu_doc_lens: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
+ raise NotImplementedError
+
+ @classmethod
+ def build(cls, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache) -> OLMoBlock:
+ if config.block_type == BlockType.sequential:
+ return CoatOLMoSequentialBlock(layer_id, config, qargs, cache)
+ elif config.block_type == BlockType.llama:
+ return CoatOLMoLlamaBlock(layer_id, config, qargs, cache)
+ else:
+ raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
+
+
+class CoatOLMoSequentialBlock(CoatOLMoBlock):
+ """
+ This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
+ (plus another skip connection). To compute it as ``LN(MLP(x + LN(Attention(x))))``,
+ use the flag `norm_after`.
+ """
+
+ def __init__(self, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache):
+ super().__init__(layer_id, config, qargs, cache)
+ # Attention input projection. Projects x -> (q, k, v)
+
+ assert not self.config.norm_after, "COAT currently does not support PostNorm"
+
+ head_dim = config.d_model // config.n_heads
+ self.fused_dims = (
+ config.d_model,
+ config.effective_n_kv_heads * head_dim,
+ config.effective_n_kv_heads * head_dim,
+ )
+
+ if self.qargs.use_quantize_model:
+ self.BeforeAttention = CoatOLMoBeforeAttentionResidual(config, qargs, self.layer_id, self.fused_dims)
+ self.AfterAttention = CoatOLMoAfterAttentionResidual(config, qargs, self.layer_id)
+ self.MLPResidual = CoatOLMoMLPResidual(config, qargs, self.layer_id, self.hidden_size)
+ else:
+ self.att_proj = nn.Linear(
+ config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device
+ )
+ # Feed-forward input projection.
+ self.ff_proj = nn.Linear(
+ config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
+ )
+
+ # Layer norms.
+ self.attn_norm = LayerNorm.build(config, size=config.d_model)
+ self.ff_norm = LayerNorm.build(config, size=config.d_model)
+
+ def reset_parameters(self):
+ super().reset_parameters()
+ self.attn_norm.reset_parameters()
+ self.ff_norm.reset_parameters()
+ # NOTE: the standard deviation for these weights does not depend on the layer.
+
+ if self.qargs.use_quantize_model: # The initialization appears here, not in CoatOLMoBlock's reset_parameters
+ if self.config.init_fn == InitFnType.normal:
+ attn_out_std = ff_out_std = self.config.init_std
+ cutoff_factor = self.config.init_cutoff_factor
+
+ elif self.config.init_fn == InitFnType.mitchell:
+ attn_out_std = 1 / (math.sqrt(2 * self.config.d_model * (self.layer_id + 1)))
+ ff_out_std = 1 / (math.sqrt(2 * self.MLPResidual.ff_out.in_features * (self.layer_id + 1)))
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
+
+ elif self.config.init_fn == InitFnType.full_megatron:
+ attn_out_std = ff_out_std = self.config.init_std / math.sqrt(2.0 * self.config.n_layers)
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
+
+ else:
+ raise NotImplementedError(self.config.init_fn)
+
+ init_normal(self.AfterAttention.attn_out, std=attn_out_std, init_cutoff_factor=cutoff_factor)
+ init_normal(self.MLPResidual.ff_out, std=ff_out_std, init_cutoff_factor=cutoff_factor)
+
+ if self.config.init_fn == InitFnType.normal:
+ std = self.config.init_std
+ cutoff_factor = self.config.init_cutoff_factor
+ elif self.config.init_fn == InitFnType.mitchell:
+ std = 1 / math.sqrt(self.config.d_model)
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
+ elif self.config.init_fn == InitFnType.full_megatron:
+ std = self.config.init_std
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
+ else:
+ raise NotImplementedError(self.config.init_fn)
+
+ if not self.qargs.use_quantize_model:
+ init_normal(self.att_proj, std, cutoff_factor)
+ init_normal(self.ff_proj, std, cutoff_factor)
+ else:
+ init_normal(self.BeforeAttention.att_proj, std, cutoff_factor)
+ init_normal(self.MLPResidual.ff_proj, std, cutoff_factor)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ qx: torch.Tensor,
+ sx: torch.Tensor,
+ attention_bias: torch.Tensor | None = None,
+ layer_past: tuple[torch.Tensor, torch.Tensor] | None = None,
+ use_cache: bool = False,
+ max_doc_len: int | None = None,
+ cu_doc_lens: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
+ # Get query, key, value projections.
+ # shape:
+ # - for regular attn q, k, v: (batch_size, seq_len, d_model)
+ # - for multi-query attn q: (batch_size, seq_len, d_model)
+ # k, v: (batch_size, seq_len, d_model // n_heads)
+ # - for group query attn q: (batch_size, seq_len, d_model)
+ # k, v: (batch_size, seq_len, d_model // n_kv_heads)
+
+ if self.qargs.use_quantize_model:
+ x, qkv = self.BeforeAttention(x, qx, sx)
+ else:
+ # apply norm before
+ h = self.attn_norm(x)
+
+ qkv = self.att_proj(h)
+
+ if self.config.clip_qkv is not None:
+ qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+ q, k, v = qkv.split(self.fused_dims, dim=-1)
+
+ # Get attention scores.
+ att, cache = self.attention(
+ q,
+ k,
+ v,
+ attention_bias,
+ layer_past=layer_past,
+ use_cache=use_cache,
+ max_doc_len=max_doc_len,
+ cu_doc_lens=cu_doc_lens,
+ )
+
+ if self.qargs.use_quantize_model:
+ x, qx, sx = self.AfterAttention(x, att)
+ else:
+ att = self.attn_out(att)
+
+ # Add attention scores.
+ # shape: (B, T, C)
+ x = x + self.dropout(att)
+
+ if self.qargs.use_quantize_model:
+ x, qx, sx = self.MLPResidual(x, qx, sx)
+ else:
+ # Add feed-forward projection.
+ # shape: (batch_size, seq_len, d_model)
+ og_x = x
+
+ x = self.ff_norm(x)
+
+ x = self.ff_proj(x)
+
+ if self._activation_checkpoint_fn is not None:
+ x = self._activation_checkpoint_fn(self.act, x) # type: ignore
+ else:
+ x = self.act(x)
+ x = self.ff_out(x)
+
+ x = self.dropout(x)
+ x = og_x + x
+
+ return x, qx, sx, cache
+
+
+class CoatOLMoLlamaBlock(CoatOLMoBlock):
+ """
+ This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
+ (plus another skip connection). This block is similar to `OLMoSequentialBlock`
+ but some operations have slightly different implementations to imitate the
+ behavior of Llama.
+ """
+
+ def __init__(self, layer_id: int, config: ModelConfig, qargs: QuantActivationConfig, cache: BufferCache):
+ super().__init__(layer_id, config, qargs, cache)
+ # Layer norms.
+ self.attn_norm = LayerNorm.build(config)
+ self.ff_norm = LayerNorm.build(config)
+ self.__cache = cache
+
+ # Attention input projection. Projects x -> (q, k, v)
+ if config.multi_query_attention:
+ q_proj_out_dim = config.d_model
+ k_proj_out_dim = config.d_model // config.n_heads
+ v_proj_out_dim = config.d_model // config.n_heads
+ else:
+ q_proj_out_dim = config.d_model
+ k_proj_out_dim = config.d_model
+ v_proj_out_dim = config.d_model
+ self.q_proj = nn.Linear(config.d_model, q_proj_out_dim, bias=config.include_bias, device=config.init_device)
+ self.k_proj = nn.Linear(config.d_model, k_proj_out_dim, bias=config.include_bias, device=config.init_device)
+ self.v_proj = nn.Linear(config.d_model, v_proj_out_dim, bias=config.include_bias, device=config.init_device)
+
+ # Feed-forward input projection.
+ self.ff_proj = nn.Linear(config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device)
+
+ def reset_parameters(self):
+ super().reset_parameters()
+ self.attn_norm.reset_parameters()
+ self.ff_norm.reset_parameters()
+ # NOTE: the standard deviation for these weights does not depend on the layer.
+
+ if self.config.init_fn == InitFnType.normal:
+ std = self.config.init_std
+ cutoff_factor = self.config.init_cutoff_factor
+ elif self.config.init_fn == InitFnType.mitchell:
+ std = 1 / math.sqrt(self.config.d_model)
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
+ elif self.config.init_fn == InitFnType.full_megatron:
+ std = self.config.init_std
+ cutoff_factor = self.config.init_cutoff_factor or 3.0
+ else:
+ raise NotImplementedError(self.config.init_fn)
+
+ init_normal(self.q_proj, std, cutoff_factor)
+ init_normal(self.k_proj, std, cutoff_factor)
+ init_normal(self.v_proj, std, cutoff_factor)
+ init_normal(self.ff_proj, std, cutoff_factor)
+
+ def _scaled_dot_product_attention(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ attn_mask: torch.Tensor | None = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ max_doc_len: int | None = None,
+ cu_doc_lens: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ if max_doc_len is not None or cu_doc_lens is not None:
+ raise NotImplementedError(f"attention document masking is not implemented for {self.__class__.__name__}")
+
+ attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
+
+ if is_causal:
+ assert attn_mask is None
+
+ query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
+ attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len]
+ elif attn_mask is not None:
+ attn_bias = attn_mask.to(q.dtype)
+ else:
+ attn_bias = torch.zeros_like(attn_weights)
+
+ attn_weights += attn_bias
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout_p)
+ return torch.matmul(attn_weights, v)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ qx: torch.Tensor,
+ sx: torch.Tensor,
+ attention_bias: torch.Tensor | None = None,
+ layer_past: tuple[torch.Tensor, torch.Tensor] | None = None,
+ use_cache: bool = False,
+ max_doc_len: int | None = None,
+ cu_doc_lens: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
+ # Get query, key, value projections.
+ # shape:
+ # - for regular attn q, k, v: (batch_size, seq_len, d_model)
+ # - for multi-query attn q: (batch_size, seq_len, d_model)
+ # k, v: (batch_size, seq_len, d_model // n_heads)
+ x_normed = self.attn_norm(x)
+ q = self.q_proj(x_normed)
+ k = self.k_proj(x_normed)
+ v = self.v_proj(x_normed)
+
+ if self.config.clip_qkv is not None:
+ q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+ k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+ v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+ # Get attention scores.
+ att, cache = self.attention(
+ q,
+ k,
+ v,
+ attention_bias,
+ layer_past=layer_past,
+ use_cache=use_cache,
+ max_doc_len=max_doc_len,
+ cu_doc_lens=cu_doc_lens,
+ )
+
+ att = self.attn_out(att) # NOTE: we move the attn_out outside the self.attention module
+
+ # Add attention scores.
+ # shape: (B, T, C)
+ x = x + self.dropout(att)
+
+ # Add feed-forward projection.
+ # shape: (batch_size, seq_len, d_model)
+ og_x = x
+ if self._activation_checkpoint_fn is not None:
+ x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
+ else:
+ x = self.ff_norm(x)
+ x = self.ff_proj(x)
+ if self._activation_checkpoint_fn is not None:
+ x = self._activation_checkpoint_fn(self.act, x) # type: ignore
+ else:
+ x = self.act(x)
+ x = self.ff_out(x)
+ x = self.dropout(x)
+ x = og_x + x
+
+ return x, cache
+
+
+class CoatOLMoBlockGroup(nn.ModuleList):
+ def __init__(self, config: ModelConfig, layer_offset: int, modules: Iterable[nn.Module] | None = None):
+ super().__init__(modules)
+ self.config = config
+ self.layer_offset = layer_offset
+ self.activation_checkpointing_strategy: ActivationCheckpointingStrategy | None = None
+ self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ attention_bias: torch.FloatTensor | None = None,
+ layers_past: list[tuple[torch.Tensor, torch.Tensor]] | None = None,
+ use_cache: bool = False,
+ max_doc_len: int | None = None,
+ cu_doc_lens: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]] | None]:
+ attn_key_values: list[tuple[torch.Tensor, torch.Tensor]] | None = [] if use_cache else None
+ for block_idx, block in enumerate(self):
+ layer_past = None if layers_past is None else layers_past[block_idx]
+ block_idx += self.layer_offset
+ if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx):
+ # shape: (batch_size, seq_len, d_model)
+ x, cache = self._activation_checkpoint_fn( # type: ignore
+ block,
+ x,
+ attention_bias=attention_bias,
+ layer_past=layer_past,
+ use_cache=use_cache,
+ max_doc_len=max_doc_len,
+ cu_doc_lens=cu_doc_lens,
+ )
+ else:
+ # shape: (batch_size, seq_len, d_model)
+ x, cache = block(
+ x,
+ attention_bias=attention_bias,
+ layer_past=layer_past,
+ use_cache=use_cache,
+ max_doc_len=max_doc_len,
+ cu_doc_lens=cu_doc_lens,
+ )
+ if attn_key_values is not None:
+ assert cache is not None
+ attn_key_values.append(cache)
+ return x, attn_key_values
+
+ def reset_parameters(self):
+ for block in self:
+ block.reset_parameters()
+
+ def set_activation_checkpointing(
+ self, strategy: ActivationCheckpointingStrategy | None, checkpoint_func: Callable | None = None
+ ):
+ self.activation_checkpointing_strategy = strategy
+ for block in self:
+ block.set_activation_checkpointing(strategy, checkpoint_func=checkpoint_func)
+
+
+class CoatOLMo(nn.Module):
+ def __init__(self, config: ModelConfig, qargs: QuantActivationConfig, init_params: bool = True):
+ super().__init__()
+ self.config = config
+ self.qargs = qargs
+ self.__cache = BufferCache()
+
+ # Validate config.
+ if self.config.alibi and self.config.flash_attention:
+ raise OLMoConfigurationError("ALiBi is currently not supported with FlashAttention")
+
+ if self.config.alibi and self.config.rope:
+ raise OLMoConfigurationError("ALiBi and RoPE are mutually exclusive")
+
+ if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
+ if self.config.embedding_size < self.config.vocab_size:
+ raise OLMoConfigurationError("embedding size should be at least as big as vocab size")
+ elif self.config.embedding_size % 128 != 0:
+ import warnings
+
+ warnings.warn(
+ "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
+ )
+
+ self.activation_checkpointing_strategy: ActivationCheckpointingStrategy | None = None
+ self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
+
+ if not (
+ 0 < self.config.block_group_size <= self.config.n_layers
+ and self.config.n_layers % self.config.block_group_size == 0
+ ):
+ raise OLMoConfigurationError("n layers must be divisible by block group size")
+
+ torch.backends.cuda.enable_flash_sdp(True)
+ torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it
+
+ self.transformer = nn.ModuleDict(
+ dict(
+ wte=nn.Embedding(config.embedding_size or config.vocab_size, config.d_model, device=config.init_device),
+ emb_drop=Dropout(config.embedding_dropout),
+ ln_f=LayerNorm.build(config),
+ )
+ )
+
+ blocks = [CoatOLMoBlock.build(i, config, qargs, self.__cache) for i in range(config.n_layers)]
+ if self.config.block_group_size > 1:
+ block_groups = [
+ CoatOLMoBlockGroup(config, i, blocks[i : i + config.block_group_size])
+ for i in range(0, config.n_layers, config.block_group_size)
+ ]
+ self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
+ else:
+ self.transformer.update({"blocks": nn.ModuleList(blocks)})
+
+ if not (self.config.alibi or self.config.rope):
+ self.transformer.update(
+ {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
+ )
+ if not config.weight_tying:
+ self.transformer.update(
+ {
+ "ff_out": nn.Linear(
+ config.d_model,
+ config.embedding_size or config.vocab_size,
+ bias=config.include_bias,
+ device=config.init_device,
+ )
+ }
+ )
+ if config.embedding_layer_norm:
+ self.transformer.update({"emb_norm": LayerNorm.build(config)})
+
+ # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
+ if init_params and self.config.init_device != "meta":
+ self.reset_parameters()
+ self.__num_fwd_flops: int | None = None
+ self.__num_bck_flops: int | None = None
+
+ # Warm up cache.
+ if self.config.alibi:
+ get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
+ self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
+
+ # Quantize
+ self.quantize_input_before_block = Coat_quantize_bgn(qargs)
+ self.quantize_output_after_block = Coat_quantize_end(qargs)
+
+ set_activation_checkpointing = OLMo.set_activation_checkpointing
+ device = OLMo.device
+ reset_parameters = OLMo.reset_parameters
+ get_alibi_attention_bias = OLMo.get_alibi_attention_bias
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ input_embeddings: torch.FloatTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ attention_bias: torch.Tensor | None = None,
+ past_key_values: Sequence[tuple[torch.Tensor, torch.Tensor]] | None = None,
+ use_cache: bool = False,
+ last_logits_only: bool = False,
+ output_hidden_states: bool | None = None,
+ doc_lens: torch.Tensor | None = None,
+ max_doc_lens: Sequence[int] | None = None,
+ ) -> OLMoOutput:
+ """
+ :param input_ids: A tensor of shape `(batch_size, seq_len)`.
+ :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
+ embeddings. When provided, it is treated as the output of the input embedding layer.
+ :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
+ which input IDs are masked. A `1` value in the mask means that
+ the corresponding input ID should *not* be ignored. A `0` means
+ that the corresponding input ID is masked.
+
+ This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
+ library.
+ :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
+ `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
+ to introduce causal or other biases.
+
+ If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
+ indicates that the i-th element in the sequence is allowed to attend to the j-th
+ element in the sequence.
+
+ If the tensor is a float tensor, it will just be added to the attention
+ scores before the softmax.
+
+ The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
+ :param past_key_values: Pre-computed keys and values for each attention block.
+ Can be used to speed up sequential decoding. The `input_ids` which have
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
+ :param use_cache: If `True`, return key and value tensors for each block.
+ :param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
+ This can speed up decoding when you only care about the next token.
+ :param doc_lens: Document lengths to use in attention for intra-document masking.
+ Shape `(batch_size, max_docs)`.
+ :param max_doc_lens: Maximum document length for each instance in the batch.
+ """
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else False
+
+ if past_key_values:
+ assert len(past_key_values) == self.config.n_layers
+
+ batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
+ if past_key_values is None:
+ past_length = 0
+ else:
+ past_length = past_key_values[0][0].size(-2)
+
+ max_doc_len: int | None = None
+ cu_doc_lens: torch.Tensor | None = None
+ if doc_lens is not None and max_doc_lens is not None:
+ max_doc_len = max(max_doc_lens)
+ cu_doc_lens = get_cumulative_document_lengths(doc_lens)
+
+ # Get embeddings of input.
+ # shape: (batch_size, seq_len, d_model)
+ x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore
+
+ # Apply embedding layer norm.
+ if self.config.embedding_layer_norm:
+ x = self.transformer.emb_norm(x)
+
+ if not (self.config.alibi or self.config.rope):
+ # Get positional embeddings.
+ # shape: (1, seq_len)
+ pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
+ # shape: (1, seq_len, d_model)
+ pos_emb = self.transformer.wpe(pos) # type: ignore
+ x = pos_emb + x
+
+ # Apply dropout.
+ # shape: (batch_size, seq_len, d_model)
+ x = self.transformer.emb_drop(x) # type: ignore
+
+ # Transform the attention mask into what the blocks expect.
+ if attention_mask is not None:
+ # shape: (batch_size, 1, 1, seq_len)
+ attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
+ attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
+
+ # Merge attention mask with attention bias.
+ if (
+ attention_bias is not None
+ or attention_mask is not None
+ or self.config.alibi
+ # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
+ # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
+ # scores correctly.
+ or past_key_values is not None
+ ):
+ if attention_bias is None and self.config.alibi:
+ attention_bias = get_causal_attention_bias(
+ self.__cache, past_length + seq_len, x.device
+ ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
+ elif attention_bias is None:
+ attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
+ elif attention_bias.dtype in (torch.int8, torch.bool):
+ attention_bias = attention_bias.to(dtype=torch.float)
+ attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
+
+ # Transform to the right shape and data type.
+ mask_len = seq_len
+ if attention_mask is not None:
+ mask_len = attention_mask.shape[-1]
+ elif past_key_values is not None:
+ mask_len = past_key_values[0][0].shape[-2] + seq_len
+ attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
+
+ # Add in the masking bias.
+ if attention_mask is not None:
+ attention_bias = attention_bias + attention_mask
+ # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
+ # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
+ # it can produce NaNs.
+ ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
+
+ attn_key_values: list[tuple[torch.Tensor, torch.Tensor]] | None = [] if use_cache else None
+
+ # decoder layers
+ all_hidden_states = []
+
+ # Prepare the inputs for the COAT decoder layers.
+ x, qx, sx = self.quantize_input_before_block(x)
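+ # Note: in COAT, `qx` and `sx` are presumably the FP8-quantized activations and their
+ # per-group scales; the full-precision `x` is carried alongside for the residual stream.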
+
+ # Apply blocks one-by-one.
+ if self.config.block_group_size == 1:
+ for block_idx, block in enumerate(self.transformer.blocks):
+ if output_hidden_states:
+ # add hidden states
+ all_hidden_states.append(x)
+
+ layer_past = None if past_key_values is None else past_key_values[block_idx]
+ if should_checkpoint_block(self.activation_checkpointing_strategy, block_idx):
+ # shape: (batch_size, seq_len, d_model)
+ x, qx, sx, cache = self._activation_checkpoint_fn(
+ block,
+ x,
+ qx,
+ sx,
+ attention_bias=attention_bias,
+ layer_past=layer_past,
+ use_cache=use_cache,
+ max_doc_len=max_doc_len,
+ cu_doc_lens=cu_doc_lens,
+ )
+ else:
+ # shape: (batch_size, seq_len, d_model)
+ x, qx, sx, cache = block(
+ x,
+ qx,
+ sx,
+ attention_bias=attention_bias,
+ layer_past=layer_past,
+ use_cache=use_cache,
+ max_doc_len=max_doc_len,
+ cu_doc_lens=cu_doc_lens,
+ )
+
+ if attn_key_values is not None:
+ assert cache is not None
+ attn_key_values.append(cache)
+ else:
+ for group_idx, block_group in enumerate(self.transformer.block_groups):
+ if output_hidden_states:
+ # add hidden states
+ all_hidden_states.append(x)
+
+ layers_past = (
+ None
+ if past_key_values is None
+ else past_key_values[
+ group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
+ ]
+ )
+ x, cache = block_group(
+ x,
+ attention_bias=attention_bias,
+ layers_past=layers_past,
+ use_cache=use_cache,
+ max_doc_len=max_doc_len,
+ cu_doc_lens=cu_doc_lens,
+ )
+ if attn_key_values is not None:
+ assert cache is not None
+ attn_key_values.extend(cache)
+
+ # Merge the decoder-layer outputs (x, qx, sx) back into a single full-precision tensor.
+ x = self.quantize_output_after_block(x, qx, sx)
+
+ if last_logits_only:
+ # shape: (batch_size, 1, d_model)
+ x = x[:, -1, :].unsqueeze(1)
+
+ # Apply final layer norm.
+ # shape: (batch_size, seq_len or 1, d_model)
+ x = self.transformer.ln_f(x) # type: ignore
+ if output_hidden_states:
+ # add final hidden state post-final-layernorm, following HuggingFace's convention
+ all_hidden_states.append(x)
+
+ # Get logits.
+ # shape: (batch_size, seq_len or 1, vocab_size)
+ if self.config.weight_tying:
+ logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
+ else:
+ logits = self.transformer.ff_out(x) # type: ignore
+ if self.config.scale_logits:
+ logits.mul_(1 / math.sqrt(self.config.d_model))
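+ # e.g. with d_model = 4096 this multiplies the logits by 1 / sqrt(4096) = 1/64.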
+
+ return OLMoOutput(
+ logits=logits,
+ attn_key_values=attn_key_values,
+ hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
+ )
+
+ def get_fsdp_wrap_policy(self, wrap_strategy: FSDPWrapStrategy | None = None):
+ if wrap_strategy is None:
+ return None
+
+ # The 'recurse' mode for the wrap function does not behave like you'd expect.
+ # Even if we return False, it may still recurse because PyTorch does what it wants,
+ # not what you want. This causes issues when, for example, we want to wrap 'ff_out' (a linear layer)
+ # but not other linear layers within a block.
+ # So we have to explicitly tell PyTorch which linear layers to wrap, and we also just
+ # return True in 'recurse' mode for simplicity.
+ size_based_module_to_wrap = {self.transformer.wte}
+ if hasattr(self.transformer, "ff_out"):
+ size_based_module_to_wrap.add(self.transformer.ff_out)
+
+ if wrap_strategy == FSDPWrapStrategy.by_block:
+
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
+ del nonwrapped_numel
+ wrap = isinstance(module, CoatOLMoBlock)
+ if recurse:
+ return True
+ else:
+ return wrap
+
+ return fsdp_wrap_fn
+ elif wrap_strategy == FSDPWrapStrategy.by_block_and_size:
+
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
+ del nonwrapped_numel
+ wrap = isinstance(module, (CoatOLMoBlock,)) or module in size_based_module_to_wrap
+ if recurse:
+ return True
+ else:
+ return wrap
+
+ return fsdp_wrap_fn
+ elif wrap_strategy == FSDPWrapStrategy.by_block_group:
+ if self.config.block_group_size <= 1:
+ raise OLMoConfigurationError(
+ "'by_block_group' FSDP wrapping strategy requires block group size greater than 1"
+ )
+
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
+ del nonwrapped_numel
+ wrap = isinstance(module, CoatOLMoBlockGroup)
+ if recurse:
+ return True
+ else:
+ return wrap
+
+ return fsdp_wrap_fn
+ elif wrap_strategy == FSDPWrapStrategy.by_block_group_and_size:
+ if self.config.block_group_size <= 1:
+ raise OLMoConfigurationError(
+ "'by_block_group_and_size' FSDP wrapping strategy requires block group size greater than 1"
+ )
+
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
+ del nonwrapped_numel
+ wrap = isinstance(module, (CoatOLMoBlockGroup,)) or module in size_based_module_to_wrap
+ if recurse:
+ return True
+ else:
+ return wrap
+
+ return fsdp_wrap_fn
+ elif wrap_strategy == FSDPWrapStrategy.size_based:
+ from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
+
+ return size_based_auto_wrap_policy
+ elif wrap_strategy in {
+ FSDPWrapStrategy.one_in_two,
+ FSDPWrapStrategy.one_in_three,
+ FSDPWrapStrategy.one_in_four,
+ FSDPWrapStrategy.one_in_five,
+ }:
+ c = {
+ FSDPWrapStrategy.one_in_two: 2,
+ FSDPWrapStrategy.one_in_three: 3,
+ FSDPWrapStrategy.one_in_four: 4,
+ FSDPWrapStrategy.one_in_five: 5,
+ }[wrap_strategy]
+
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
+ del nonwrapped_numel
+ wrap = isinstance(module, CoatOLMoBlock) and module.layer_id % c == 0
+ if recurse:
+ return True
+ else:
+ return wrap
+
+ return fsdp_wrap_fn
+ else:
+ raise NotImplementedError(wrap_strategy)
+
+ num_params = OLMo.num_params
+
+ @property
+ def num_fwd_flops(self):
+ if self.__num_fwd_flops:
+ return self.__num_fwd_flops
+
+ # embedding table is just a lookup in the forward pass
+ n_params = self.num_params(include_embedding=False)
+ # The number of parameters approximates the number of multiply-accumulates (MACs) in the network.
+ # Each MAC costs 2 FLOPs, so we multiply by 2, i.e. 2 * n_params.
+ # This gives us FLOPs per token.
+ params_flops_per_token = 2 * n_params
+ # Attention also costs 2 FLOPs per MAC, and there are two matmuls per layer
+ # (A = Q @ K^T and out = A @ V), hence the extra factor of 2.
+ attn_flops_per_token = self.config.n_layers * 2 * 2 * (self.config.d_model * self.config.max_sequence_length)
+ self.__num_fwd_flops = params_flops_per_token + attn_flops_per_token
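+ # e.g. for a ~7B-parameter model the parameter term alone contributes roughly 2 * 7e9 = 1.4e10 FLOPs per token.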
+ return self.__num_fwd_flops
+
+ @property
+ def num_bck_flops(self):
+ if self.__num_bck_flops:
+ return self.__num_bck_flops
+
+ n_params = self.num_params()
+ params_flops_per_token = 4 * n_params
+ attn_flops_per_token = self.config.n_layers * 8 * (self.config.d_model * self.config.max_sequence_length)
+ self.__num_bck_flops = params_flops_per_token + attn_flops_per_token
+ return self.__num_bck_flops
+
+ generate = OLMo.generate
+
+ @classmethod
+ def from_checkpoint(
+ cls, checkpoint_dir: PathOrStr, device: str = "cpu", checkpoint_type: CheckpointType | None = None
+ ) -> CoatOLMo:
+ """
+ Load an OLMo model from a checkpoint.
+ """
+ from olmo.util import resource_path
+
+ # Guess checkpoint type.
+ if checkpoint_type is None:
+ try:
+ if resource_path(checkpoint_dir, "model.pt").is_file():
+ checkpoint_type = CheckpointType.unsharded
+ else:
+ checkpoint_type = CheckpointType.sharded
+ except FileNotFoundError:
+ checkpoint_type = CheckpointType.sharded
+
+ # Load config.
+ config_path = resource_path(checkpoint_dir, "config.yaml")
+ model_config = ModelConfig.load(config_path, key="model", validate_paths=False)
+
+ if checkpoint_type == CheckpointType.unsharded:
+ # Initialize model (always on CPU to start with so we don't run out of GPU memory).
+ model_config.init_device = "cpu"
+ model = CoatOLMo(model_config)
+
+ # Load state dict directly to target device.
+ state_dict_path = resource_path(checkpoint_dir, "model.pt")
+ state_dict = torch.load(state_dict_path, map_location="cpu")
+ model.load_state_dict(model._make_state_dict_compatible(state_dict)[0])
+ model = model.to(torch.device(device))
+ else:
+ train_config = TrainConfig.load(config_path)
+ if train_config.sharded_checkpointer == ShardedCheckpointerType.olmo_core:
+ from olmo_core.distributed.checkpoint import load_model_and_optim_state # type: ignore
+
+ model_config.init_device = device
+ model = CoatOLMo(model_config)
+ load_model_and_optim_state(checkpoint_dir, model)
+ else:
+ # train_config.sharded_checkpointer == ShardedCheckpointerType.torch_new
+ from olmo.checkpoint import load_model_state
+
+ # Initialize model on target device. In this case the state dict is loaded in-place
+ # so it's not necessary to start on CPU if the target device is a GPU.
+ model_config.init_device = device
+ model = CoatOLMo(model_config)
+
+ # Load state dict in place.
+ load_model_state(checkpoint_dir, model)
+
+ return model.eval()
+
+ def _make_state_dict_compatible(
+ self, state_dict: dict[str, torch.Tensor]
+ ) -> tuple[dict[str, torch.Tensor], dict[str, set[str]]]:
+ """
+ Handles some cases where the state dict is valid yet may need to be transformed in order to
+ be loaded.
+
+ This modifies the state dict in-place and also returns it, along with a mapping of original key
+ names to new key names in cases where the keys were simply renamed. That mapping can be used
+ to make a corresponding optimizer state dict compatible as well.
+ """
+ import re
+ from fnmatch import fnmatch
+
+ new_keys_to_og_keys: dict[str, str] = {}
+
+ # Remove "_fsdp_wrapped_module." prefix from all keys. We don't want this prefix when the model is
+ # not wrapped in FSDP. And when the model is wrapped in FSDP, loading this state dict will still work
+ # fine without the prefixes. This also simplifies the other steps below.
+ for key in list(state_dict.keys()):
+ state_dict[(new_key := key.replace("_fsdp_wrapped_module.", ""))] = state_dict.pop(key)
+ new_keys_to_og_keys[new_key] = key
+
+ # For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222
+ if self.config.block_type == BlockType.sequential:
+ for key in list(state_dict.keys()):
+ if fnmatch(key, "transformer.*.norm.weight"):
+ tensor = state_dict.pop(key)
+ state_dict[(new_key := key.replace("norm.weight", "attn_norm.weight"))] = tensor
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ state_dict[(new_key := key.replace("norm.weight", "ff_norm.weight"))] = tensor.clone()
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ del new_keys_to_og_keys[key]
+ elif fnmatch(key, "transformer.*.norm.bias"):
+ tensor = state_dict.pop(key)
+ state_dict[(new_key := key.replace("norm.bias", "attn_norm.bias"))] = tensor
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ state_dict[(new_key := key.replace("norm.bias", "ff_norm.bias"))] = tensor.clone()
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ del new_keys_to_og_keys[key]
+
+ # Real quantization changes where the linear layers live inside each block (BeforeAttention / AfterAttention / MLPResidual), so their keys must be remapped.
+ if self.qargs.use_quantize_model == "coat_real":
+ for key in list(state_dict.keys()):
+ if fnmatch(key, "transformer.blocks.*.att_proj.weight") and "BeforeAttention" not in key:
+ tensor = state_dict.pop(key)
+ state_dict[(new_key := key.replace("att_proj.weight", "BeforeAttention.att_proj.weight"))] = tensor
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ del new_keys_to_og_keys[key]
+ elif fnmatch(key, "transformer.blocks.*.attn_out.weight") and "AfterAttention" not in key:
+ tensor = state_dict.pop(key)
+ state_dict[(new_key := key.replace("attn_out.weight", "AfterAttention.attn_out.weight"))] = tensor
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ del new_keys_to_og_keys[key]
+ elif fnmatch(key, "transformer.blocks.*.ff_proj.weight") and "MLPResidual" not in key:
+ tensor = state_dict.pop(key)
+ state_dict[(new_key := key.replace("ff_proj.weight", "MLPResidual.ff_proj.weight"))] = tensor
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ del new_keys_to_og_keys[key]
+ elif fnmatch(key, "transformer.blocks.*.ff_out.weight") and "MLPResidual" not in key:
+ tensor = state_dict.pop(key)
+ state_dict[(new_key := key.replace("ff_out.weight", "MLPResidual.ff_out.weight"))] = tensor
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
+ del new_keys_to_og_keys[key]
+
+ # For loading a state dict that was saved with a different `block_group_size`.
+ if "transformer.block_groups.0.0.attn_out.weight" in state_dict.keys():
+ state_dict_block_group_size = len(
+ [k for k in state_dict.keys() if fnmatch(k, "transformer.block_groups.0.*.attn_out.weight")]
+ )
+ else:
+ state_dict_block_group_size = 1
+ if self.config.block_group_size != state_dict_block_group_size:
+ log.info(
+ f"Regrouping state dict blocks from group size {state_dict_block_group_size} to "
+ f"group size {self.config.block_group_size}"
+ )
+ # For simplicity we're first going to flatten out the block groups in the state dict (if necessary)
+ # and then (re-)group them into the right block sizes.
+ if state_dict_block_group_size > 1:
+ for key in list(state_dict.keys()):
+ if (m := re.match(r"transformer.block_groups\.(\d+)\.(\d+)\..*", key)) is not None:
+ group_idx, group_block_idx = int(m.group(1)), int(m.group(2))
+ block_idx = (group_idx * state_dict_block_group_size) + group_block_idx
+ state_dict[
+ (
+ new_key := key.replace(
+ f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}."
+ )
+ )
+ ] = state_dict.pop(key)
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)
+
+ if self.config.block_group_size > 1:
+ # Group the state dict blocks into the right block size.
+ for key in list(state_dict.keys()):
+ if (m := re.match(r"transformer.blocks\.(\d+)\..*", key)) is not None:
+ block_idx = int(m.group(1))
+ group_idx, group_block_idx = (
+ block_idx // self.config.block_group_size,
+ block_idx % self.config.block_group_size,
+ )
+ state_dict[
+ (
+ new_key := key.replace(
+ f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}."
+ )
+ )
+ ] = state_dict.pop(key)
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)
+
+ og_keys_to_new: dict[str, set[str]] = defaultdict(set)
+ for new_key, og_key in new_keys_to_og_keys.items():
+ og_keys_to_new[og_key].add(new_key)
+
+ return state_dict, og_keys_to_new
diff --git a/llava/model/coat/activation/real_quantization/__init__.py b/llava/model/coat/activation/real_quantization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..61875efb5d6a495fdab36438ea99a8784a88c4cf
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/__init__.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Activation
+# Utils
+from ._dequantize import fp8_dequantize
+from ._division import fp8_division
+from ._division_transpose import fp8_division_transpose
+from ._quantize import fp8_quantize
+from ._quantize_pertensor import fp8_quantize_pertensor
+from ._quantize_pertensor_transpose import fp8_quantize_pertensor_transpose
+from ._transpose import fp8_transpose
+from .add_bwd import fp8_add_Ifp_Ifp_Ofp_Opt
+from .add_fwd import fp8_add_Ifp_Ifp_Ofp_Og16
+
+# Normalization
+from .func_layernorm_noparam import fp8_layernorm_noparam_backward, fp8_layernorm_noparam_forward
+from .func_quantize import Coat_quantize_bgn, Coat_quantize_end
+from .func_rmsnorm import fp8_rmsnorm_backward, fp8_rmsnorm_forward
+from .gelu_bwd import fp8_gelu_backward
+from .gelu_fwd import fp8_gelu_forward
+
+# linear and add
+from .linear import fp8_linear_backward, fp8_linear_forward
+from .mul_bwd import fp8_mul_backward
+from .mul_fwd import fp8_mul_forward
+from .silu_bwd import fp8_silu_backward
+from .silu_fwd import fp8_silu_forward
diff --git a/llava/model/coat/activation/real_quantization/_dequantize.py b/llava/model/coat/activation/real_quantization/_dequantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..4347932f9c3691ffada9d59258337aac75b74942
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/_dequantize.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block
+
+"""Quantize Operator"""
+"""Input uses 1 * 16 group quantization"""
+"""Output uses 1 * 16 group quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_dequantize_kernel(
+ output_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_input_stride_0,
+ s_input_stride_1, # scale of output stride
+ output_stride_0,
+ output_stride_1, # output stride
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ # input ptr
+ scale_input_ptr = tl.make_block_ptr(
+ base=input_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input_stride_0, s_input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ scale_input = tl.load(scale_input_ptr)
+
+ input = input.to(tl.float32)
+ scale_input = scale_input.to(tl.float32)
+
+ # Dequantize calculation
+ scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1))
+ input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+ output = input * scale_input
+
+ output = tl.reshape(output, (BLOCK_M, BLOCK_N))
+ output = output.to(output_ptr.dtype.element_ty)
+
+ # debug
+ # gelu_output = input
+ # scale_output = scale_input
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
+
+
+def fp8_dequantize(x, s_x, QB):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ SN = N // QB
+
+ y = torch.empty_like(x, dtype=torch.bfloat16)
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_dequantize_kernel[grid](
+ y,
+ x,
+ s_x,
+ M,
+ N,
+ SN,
+ QB,
+ x.stride(0),
+ x.stride(1),
+ s_x.stride(0),
+ s_x.stride(1),
+ y.stride(0),
+ y.stride(1),
+ )
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+
+ return y
diff --git a/llava/model/coat/activation/real_quantization/_division.py b/llava/model/coat/activation/real_quantization/_division.py
new file mode 100644
index 0000000000000000000000000000000000000000..335ff1a67795dc858a740c6618466ff79fb554cb
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/_division.py
@@ -0,0 +1,212 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block
+
+"""Quantize and Transpose Operator"""
+"""Input uses 1 * 16 group quantization"""
+"""Output uses 1 * 16 group quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_division_kernel(
+ output_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ noise_ptr, # noise for stochastic
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max,
+ e_bit: tl.constexpr,
+ m_bit: tl.constexpr, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ SCALE_MIN_THRES: tl.constexpr, # Unused here: SCALE_MIN_THRES should already have been applied by the kernel that computed the scaling factor
+ STOCHASTIC: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ input = input.to(tl.float32)
+ scale_output = tl.load(input_scale_ptr)
+ scale_output = scale_output.to(tl.float32)
+
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+
+ # Quantize Scale calculation
+ # Quantize
+ output = tl.fdiv(output, scale_output)
+ output = tl.reshape(output, (BLOCK_M, BLOCK_N))
+
+ if STOCHASTIC:
+ # noise_block_ptr = tl.make_block_ptr(
+ # base=noise_ptr,
+ # shape=(M, N),
+ # strides=(input_stride_0, input_stride_1),
+ # offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ # block_shape=(BLOCK_M, BLOCK_N),
+ # order=(1, 0)
+ # )
+ # noise = tl.load(noise_block_ptr)
+
+ offs_m = pid_dim0 * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_dim1 * BLOCK_N + tl.arange(0, BLOCK_N)
+ noise_offset = offs_m[:, None] * input_stride_0 + offs_n[None, :] * input_stride_1
+ noise = tl.rand(0, noise_offset)
+
+ output = _stochastic_rounding(output, noise, e_bit, m_bit)
+
+ output = output.to(output_ptr.type.element_ty)
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
+
+
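+# Stochastic rounding (sketch): the value's exponent is extracted, noise scaled to the local
+# quantization step is added to the magnitude, and the result is clamped back into the
+# representable range, so rounding is unbiased in expectation.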
+@triton.jit
+def _stochastic_rounding(output, noise, e_bit: tl.constexpr, m_bit: tl.constexpr):
+ subnormal_min = tl.exp2(2 - tl.exp2(e_bit - 1) - m_bit)
+ # subnormal_should_be = tl.exp2(2 - tl.exp2(e_bit) - 1)
+
+ output_int32 = tl.cast(output, tl.int32, bitcast=True)
+ output_int32 = output_int32 & 0x7F800000
+ output_float32 = tl.cast(output_int32, tl.float32, bitcast=True)
+ output_exp = tl.maximum(output_float32, subnormal_min)
+
+ noise_rescale = tl.exp2(m_bit) + (output_exp == subnormal_min) * (
+ 1 - tl.exp2(m_bit)
+ ) # 2^m_bit for normal, 1 for subnormal
+
+ noise = output_exp * noise / noise_rescale
+ sign = 1 - 2 * libdevice.signbit(output)
+ output = tl.abs(output) + noise
+
+ minmax_ratio = 2 + (output_exp == subnormal_min) * (tl.exp2(m_bit) - 2) # 2 for normal, and 2^M for subnormal
+ output = sign * tl.clamp(output, min=output_exp, max=minmax_ratio * output_exp)
+
+ return output
+
+
+def fp8_division(x, QB, fp8type, s_y=None, stochastic=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+
+ if stochastic:
+ # noise = torch.zeros_like(x, dtype=torch.float32).uniform_(-0.5, 0.5)
+ noise = None
+ else:
+ noise = None
+
+ # defining the input and output tensor
+ M, N = x.shape
+ SN = N // QB
+
+ if isinstance(fp8type, str):
+ fp8type = convert_str_to_fp8[fp8type]
+
+ y = torch.empty_like(x, dtype=fp8type)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+ e_bit, m_bit = convert_fp8_to_embit[fp8type]
+
+ if s_y is None:
+ s_y = (x.abs().max() + SCALE_MIN_THRES) / fp8MaxValue
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_division_kernel[grid](
+ y,
+ x,
+ s_y,
+ noise,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ e_bit,
+ m_bit,
+ x.stride(0),
+ x.stride(1),
+ y.stride(0),
+ y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ STOCHASTIC=stochastic,
+ )
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+
+ return y, s_y # s_y is the per-tensor scale actually used for the division
diff --git a/llava/model/coat/activation/real_quantization/_division_transpose.py b/llava/model/coat/activation/real_quantization/_division_transpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fbabaa31e4f47482c98d69431a9dfae488a2f15
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/_division_transpose.py
@@ -0,0 +1,215 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division import _stochastic_rounding
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block
+
+"""Division and Transpose Operator"""
+"""Input uses full-precision/BF16"""
+"""Output uses per tensor quantization"""
+"""Output_t uses per tensor quantization and is transposed, but is flattened to 2D"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(), # triton.Config({'BLOCK_M': 1, 'BLOCK_N': 16}, num_stages=4, num_warps=1,)
+ # configs=[triton.Config({'BLOCK_M': 1, 'BLOCK_N': 16}, num_stages=4, num_warps=1,)], #
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_division_transpose_kernel(
+ output_ptr,
+ output_t_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ noise_ptr, # noise for stochastic
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max,
+ e_bit,
+ m_bit, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ output_t_stride_0,
+ output_t_stride_1, # output stride
+ SCALE_MIN_THRES: tl.constexpr, # Unused here: SCALE_MIN_THRES should already have been applied by the kernel that computed the scaling factor
+ STOCHASTIC: tl.constexpr,
+ ONLY_TRANSPOSED: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ input = input.to(tl.float32)
+ scale_output = tl.load(input_scale_ptr)
+ scale_output = scale_output.to(tl.float32)
+
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+
+ # Quantize Scale calculation
+ # Quantize
+ output = tl.fdiv(output, scale_output)
+ output = tl.reshape(output, (BLOCK_M, BLOCK_N))
+
+ if STOCHASTIC:
+ # noise_block_ptr = tl.make_block_ptr(
+ # base=noise_ptr,
+ # shape=(M, N),
+ # strides=(input_stride_0, input_stride_1),
+ # offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ # block_shape=(BLOCK_M, BLOCK_N),
+ # order=(1, 0)
+ # )
+ # noise = tl.load(noise_block_ptr)
+
+ offs_m = pid_dim0 * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_dim1 * BLOCK_N + tl.arange(0, BLOCK_N)
+ noise_offset = offs_m[:, None] * input_stride_0 + offs_n[None, :] * input_stride_1
+ noise = tl.rand(0, noise_offset)
+
+ output = _stochastic_rounding(output, noise, e_bit, m_bit)
+
+ output = output.to(output_ptr.type.element_ty)
+ # tl.device_print("3: ", output)
+ output_t = tl.trans(output)
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ output_t_block_ptr = tl.make_block_ptr(
+ base=output_t_ptr,
+ shape=(N, M),
+ strides=(output_t_stride_0, output_t_stride_1),
+ offsets=(pid_dim1 * BLOCK_N, pid_dim0 * BLOCK_M),
+ block_shape=(BLOCK_N, BLOCK_M),
+ order=(1, 0),
+ )
+ if not ONLY_TRANSPOSED:
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
+ tl.store(output_t_block_ptr, output_t, boundary_check=(0, 1))
+
+
+def fp8_division_transpose(x, QB, fp8type, s_y=None, stochastic=False, only_transposed=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+
+ if stochastic:
+ # noise = torch.empty_like(x, dtype=torch.float32).uniform_(-0.5, 0.5)
+ noise = None
+ else:
+ noise = None
+
+ # defining the input and output tensor
+ M, N = x.shape
+ SN = N // QB
+
+ if isinstance(fp8type, str):
+ fp8type = convert_str_to_fp8[fp8type]
+
+ y = torch.empty_like(x, dtype=fp8type)
+ y_t = torch.empty((N, M), dtype=fp8type, device=x.device)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+ e_bit, m_bit = convert_fp8_to_embit[fp8type]
+
+ if s_y is None:
+ # print("Warning: do not specify s_y in fp8_division_transpose")
+ s_y = (x.abs().max() + SCALE_MIN_THRES) / fp8MaxValue
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_division_transpose_kernel[grid](
+ y,
+ y_t,
+ x,
+ s_y,
+ noise,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ e_bit,
+ m_bit,
+ x.stride(0),
+ x.stride(1),
+ y.stride(0),
+ y.stride(1),
+ y_t.stride(0),
+ y_t.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ STOCHASTIC=stochastic,
+ ONLY_TRANSPOSED=only_transposed,
+ )
+
+ if not only_transposed:
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+
+ return y, s_y, y_t # y_t is expected to be 2D tensor
+ else:
+ return y_t, s_y
diff --git a/llava/model/coat/activation/real_quantization/_memory_io.py b/llava/model/coat/activation/real_quantization/_memory_io.py
new file mode 100644
index 0000000000000000000000000000000000000000..b04d2a7afcf749b634cbd5d88c6d99b42030d5ff
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/_memory_io.py
@@ -0,0 +1,180 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+CONST_BLOCK = 32
+
+# Autotune configurations for the IO benchmark kernel below (one load and one store per block).
+def get_configs_io_block():
+ configs = []
+ for nstages in [3, 4, 5, 6]:
+ for block_m in [32, 64, 128]:
+ for block_n in [32, 64, 128]:
+ for nwarps in [4, 8, 16, 32]:
+ configs.append(
+ triton.Config(
+ {"BLOCK_M": block_m, "BLOCK_N": block_n},
+ num_stages=nstages,
+ num_warps=nwarps,
+ )
+ )
+ return configs
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.jit
+def bench_memory_io_kernel_forward(
+ output_ptr,
+ input_ptr,
+ M,
+ N,
+ B: tl.constexpr,
+ input_stride_0,
+ input_stride_1,
+ output_stride_0,
+ output_stride_1,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+):
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_M = tl.cdiv(M, BLOCK_M)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ input = input.to(tl.float32)
+
+ output = input * 2
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ output = output.to(output_ptr.type.element_ty)
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
+
+
+def bench_memory_io_forward(x, B):
+ # defining the input and output tensor
+ M, N = x.shape
+
+ y = torch.empty_like(x, dtype=x.dtype)
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ bench_memory_io_kernel_forward[grid](
+ y,
+ x,
+ M,
+ N,
+ B,
+ x.stride(0),
+ x.stride(1),
+ y.stride(0),
+ y.stride(1),
+ )
+ return y
+
+
+configs = []
+for SL in [8192]:
+ configs.append(
+ triton.testing.Benchmark( # test different matrix size influence
+ x_names=["CDIM"],
+ x_vals=[1024, 2048, 4096, 8192],
+ line_arg="dtype",
+ line_vals=[torch.int8, torch.float16, torch.float32],
+ line_names=["float8", "float16", "float32"],
+ styles=[("blue", "-"), ("green", "-"), ("red", "-")],
+ ylabel="time-cost",
+ plot_name=f"INT8GELU",
+ args={"SL": SL, "B": CONST_BLOCK, "provider": "triton", "mode": "time-consuming"},
+ )
+ )
+
+
+@triton.testing.perf_report(configs)
+def bench_load_store(
+ SL, CDIM, B, provider, dtype, mode="forward"
+): # I only use triton as the provider, and mode when benchmarking
+ # create data
+ x = torch.randn(SL, CDIM, dtype=torch.float32).cuda()
+ x = x.to(dtype)
+
+ quantiles = [0.5, 0.2, 0.8]
+ # utility functions
+ if provider == "triton":
+
+ def y_fwd():
+ bench_memory_io_forward(x, B)
+
+ if provider == "torch":
+ torch_gelu = torch.nn.GELU()
+
+ def y_fwd():
+ return torch_gelu(x)
+
+ # forward pass
+ if mode == "time-consuming":
+ convert_func = lambda ms: ms
+ ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=100)
+ # backward pass
+ if mode == "gbps":
+ convert_func = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
+ ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=100)
+ return convert_func(ms), convert_func(max_ms), convert_func(min_ms)
+
+
+if __name__ == "__main__":
+ torch.manual_seed(0)
+ torch.set_printoptions(precision=8, linewidth=1600, sci_mode=False, edgeitems=3)
+ bench_load_store.run(print_data=True)
diff --git a/llava/model/coat/activation/real_quantization/_quantize.py b/llava/model/coat/activation/real_quantization/_quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1385b7dc42f25f0505a42cfecad48d9fa6eaa51
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/_quantize.py
@@ -0,0 +1,176 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block
+
+"""Quantize Operator"""
+"""Input uses 1 * 16 group quantization"""
+"""Output uses 1 * 16 group quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_quantize_kernel(
+ output_ptr,
+ output_scale_ptr, # output
+ input_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ input = input.to(tl.float32)
+
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ # Quantize
+ output = tl.fdiv(output, scale_output)
+
+ output = output.to(output_ptr.type.element_ty)
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+ output = tl.reshape(output, (BLOCK_M, BLOCK_N))
+
+ # debug
+ # gelu_output = input
+ # scale_output = scale_input
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_quantize(x, QB, fp8type):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ SN = N // QB
+
+ if isinstance(fp8type, str):
+ fp8type = convert_str_to_fp8[fp8type]
+ y = torch.empty_like(x, dtype=fp8type)
+ s_y = torch.empty((M, SN), dtype=torch.bfloat16, device=x.device)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_quantize_kernel[grid](
+ y,
+ s_y,
+ x,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ y.stride(0),
+ y.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+
+ return y, s_y
diff --git a/llava/model/coat/activation/real_quantization/_quantize_pertensor.py b/llava/model/coat/activation/real_quantization/_quantize_pertensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..46cadbc6e77e2e34c1fafec968aa17c5621da034
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/_quantize_pertensor.py
@@ -0,0 +1,152 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division import fp8_division
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block
+
+"""Per Tensor Quantize Operator"""
+"""Input uses full precision"""
+"""Output uses per tensor quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_quantize_pertensor_kernel(
+ output_scale_ptr, # output
+ input_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ input = input.to(tl.float32)
+
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_quantize_pertensor(x, QB, fp8type, stochastic=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ SN = N // QB
+
+ fp8type = convert_str_to_fp8[fp8type]
+ s_y = torch.empty((M, SN), dtype=torch.bfloat16, device=x.device)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_quantize_pertensor_kernel[grid](
+ s_y,
+ x,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ s_y_max = s_y.max()
+ y, s_y_max = fp8_division(x, QB, fp8type, s_y_max, stochastic=stochastic) # quantize x with the single per-tensor scale
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+
+ return y, s_y_max, s_y
diff --git a/llava/model/coat/activation/real_quantization/_quantize_pertensor_transpose.py b/llava/model/coat/activation/real_quantization/_quantize_pertensor_transpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e91d01445c99928aa7f79227cdbc7706fee1981
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/_quantize_pertensor_transpose.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division_transpose import fp8_division_transpose
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block
+
+"""Per Tensor Quantize and Transpose Operator"""
+"""Input uses floating point tensor"""
+"""Output uses per-tensor quantization, returns a non-transpose version and a transpose version"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_quantize_pertensor_transpose_kernel(
+ output_scale_ptr, # output
+ input_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ input = input.to(tl.float32)
+
+ output = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_quantize_pertensor_transpose(x, QB, fp8type, transpose_output_2d=False, stochastic=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ SN = N // QB
+
+ fp8type = convert_str_to_fp8[fp8type]
+ s_y = torch.empty((M, SN), dtype=torch.bfloat16, device=x.device)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_quantize_pertensor_transpose_kernel[grid](
+ s_y,
+ x,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ s_y_max = s_y.max()
+ qy, s_y_max, qy_t = fp8_division_transpose(
+ x, QB, fp8type, s_y_max, stochastic=stochastic
+ ) # Stochastic Rounding happens here
+
+ # Recover 2D to 3D
+ if batched:
+ qy = qy.reshape(BS, -1, qy.shape[-1])
+ if not transpose_output_2d:
+ qy_t = qy_t.reshape(BS, -1, qy_t.shape[-1])
+
+ return qy, s_y_max, qy_t # y_t is expected to be 2D tensor
diff --git a/llava/model/coat/activation/real_quantization/_transpose.py b/llava/model/coat/activation/real_quantization/_transpose.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b0b8ec2dc499d2bcd661e619d9f54e91daeac0c
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/_transpose.py
@@ -0,0 +1,121 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from .common import get_configs_io_block
+
+"""Quantize Operator"""
+"""Input uses 1 * 16 group quantization"""
+"""Output uses 1 * 16 group quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.jit
+def _fp8_transpose_kernel(
+ output_ptr, # output
+ input_ptr, # input
+ M,
+ N, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+
+ output = tl.trans(input)
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(N, M),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim1 * BLOCK_N, pid_dim0 * BLOCK_M),
+ block_shape=(BLOCK_N, BLOCK_M),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, output, boundary_check=(0, 1))
+
+
+def fp8_transpose(x, transpose_output_2d=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+
+ y = torch.empty((N, M), dtype=x.dtype, device=x.device)
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_transpose_kernel[grid](
+ y,
+ x,
+ M,
+ N,
+ x.stride(0),
+ x.stride(1),
+ y.stride(0),
+ y.stride(1),
+ )
+
+ # Recover 2D to 3D
+ if batched and not transpose_output_2d:
+ y = y.reshape(BS, -1, y.shape[-1])
+
+ return y
diff --git a/llava/model/coat/activation/real_quantization/add_bwd.py b/llava/model/coat/activation/real_quantization/add_bwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..5951afec125048d8549fd1be0ae204df1a6b222a
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/add_bwd.py
@@ -0,0 +1,205 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division import fp8_division
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block
+
+"""Element-wise Add, useful for backward"""
+"""Input1 (Residual) uses full-precision/BF16"""
+"""Input2 (Backbone) uses full-precision/BF16"""
+"""Output1 uses full-precision/BF16"""
+"""Output2 uses per-tensor quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_add_Ifp_Ifp_Ofp_Opt_kernel(
+ output1_ptr, # output
+ output2_scale_ptr,
+ input1_ptr, # input
+ input2_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input1_stride_0,
+ input1_stride_1, # input1 stride
+ input2_stride_0,
+ input2_stride_1, # input2 stride
+ output1_stride_0,
+ output1_stride_1, # output stride
+ s_output2_stride_0,
+ s_output2_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # --- The first input ---
+ input1_block_ptr = tl.make_block_ptr(
+ base=input1_ptr,
+ shape=(M, N),
+ strides=(input1_stride_0, input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input1 = tl.load(input1_block_ptr)
+ input1 = input1.to(tl.float32)
+ input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB))
+
+ # --- The second input ---
+ input2_block_ptr = tl.make_block_ptr(
+ base=input2_ptr,
+ shape=(M, N),
+ strides=(input2_stride_0, input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input2 = tl.load(input2_block_ptr)
+ input2 = input2.to(tl.float32)
+ input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB))
+
+ # Actual Calculation of Add
+ add_output = input1 + input2
+
+    # Scale calculation for quantizing the add output (output2)
+ abs_add_output = tl.abs(add_output)
+ max_val = tl.max(abs_add_output, axis=2) + SCALE_MIN_THRES
+ scale_output2 = max_val / fp8_max
+ scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN, 1))
+
+ # save the fp add output
+ fp_add_output = add_output.to(output1_ptr.type.element_ty)
+ fp_add_output = tl.reshape(fp_add_output, (BLOCK_M, BLOCK_N))
+
+ # pointers
+ output1_block_ptr = tl.make_block_ptr(
+ base=output1_ptr,
+ shape=(M, N),
+ strides=(output1_stride_0, output1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ tl.store(output1_block_ptr, fp_add_output, boundary_check=(0, 1))
+
+ # Quantize
+ scale_output2 = scale_output2.to(output2_scale_ptr.type.element_ty)
+ scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN))
+
+ # pointers
+ scale_output2_ptr = tl.make_block_ptr(
+ base=output2_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output2_stride_0, s_output2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+ tl.store(scale_output2_ptr, scale_output2, boundary_check=(0, 1))
+
+
+def fp8_add_Ifp_Ifp_Ofp_Opt(x1, x2, QB, fp8type, stochastic=False): # suppose x1 is full precision or BF16
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x1.shape) == 3:
+ assert len(x2.shape) == 3
+ batched = True
+ BS = x1.shape[0]
+ x1 = x1.reshape(-1, x1.shape[-1])
+ x2 = x2.reshape(-1, x2.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x1.shape
+ SN = N // QB
+ assert x1.shape == x2.shape
+
+ if isinstance(fp8type, str):
+ fp8type = convert_str_to_fp8[fp8type]
+ y1 = torch.empty_like(x1, dtype=torch.bfloat16)
+ s_y2 = torch.empty((M, SN), dtype=torch.bfloat16, device=x2.device)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_add_Ifp_Ifp_Ofp_Opt_kernel[grid](
+ y1,
+ s_y2,
+ x1,
+ x2,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x1.stride(0),
+ x1.stride(1),
+ x2.stride(0),
+ x2.stride(1),
+ y1.stride(0),
+ y1.stride(1),
+ s_y2.stride(0),
+ s_y2.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ s_y2_max = s_y2.max()
+ qy2, s_y2_max = fp8_division(y1, QB, fp8type, s_y2_max, stochastic=stochastic) # reuse the floating point output y1
+
+ # Recover 2D to 3D
+ if batched:
+ y1 = y1.reshape(BS, -1, y1.shape[-1])
+ qy2 = qy2.reshape(BS, -1, qy2.shape[-1])
+ s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1])
+
+ return y1, (qy2, s_y2_max, s_y2)
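+
+
+# Illustrative sketch of the return contract (assumptions: BF16 inputs on a CUDA device,
+# QB = 16 group size, FP8 format given as a string understood by convert_str_to_fp8).
+#
+#   y1, (qy2, s_y2_max, s_y2) = fp8_add_Ifp_Ifp_Ofp_Opt(residual, backbone, 16, "E4M3")
+#   # y1:  BF16 element-wise sum, same shape as the inputs
+#   # qy2: per-tensor-quantized FP8 copy of the sum, with scale s_y2_max
+#   # s_y2: the intermediate 1 * 16 group scales (BF16)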
diff --git a/llava/model/coat/activation/real_quantization/add_fwd.py b/llava/model/coat/activation/real_quantization/add_fwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..90aed0b3573e4c1a21b4eecea959f6f537f61be8
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/add_fwd.py
@@ -0,0 +1,219 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block
+
+"""Element-wise Add, used in forward pass"""
+"""Input1 (Residual) uses full-precision/BF16"""
+"""Input2 (Backbone) uses full-precision/BF16"""
+"""Output1 uses full-precision/BF16"""
+"""Output2 uses 1 * 16 group quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_add_Ifp_Ifp_Ofp_Og16_kernel(
+ output1_ptr, # output
+ output2_ptr,
+ output2_scale_ptr,
+ input1_ptr, # input
+ input2_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input1_stride_0,
+ input1_stride_1, # input1 stride
+ input2_stride_0,
+ input2_stride_1, # input2 stride
+ output1_stride_0,
+ output1_stride_1, # output stride
+ output2_stride_0,
+ output2_stride_1, # output stride
+ s_output2_stride_0,
+ s_output2_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # --- The first input ---
+ input1_block_ptr = tl.make_block_ptr(
+ base=input1_ptr,
+ shape=(M, N),
+ strides=(input1_stride_0, input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input1 = tl.load(input1_block_ptr)
+ input1 = input1.to(tl.float32)
+ input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB))
+
+ # --- The second input ---
+ input2_block_ptr = tl.make_block_ptr(
+ base=input2_ptr,
+ shape=(M, N),
+ strides=(input2_stride_0, input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ input2 = tl.load(input2_block_ptr)
+ input2 = input2.to(tl.float32)
+ input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB))
+
+ # Actual Calculation of Add
+ add_output = input1 + input2
+
+    # Scale calculation for quantizing output2
+ abs_add_output = tl.abs(add_output)
+ max_val = tl.max(abs_add_output, axis=2) + SCALE_MIN_THRES
+ scale_output2 = max_val / fp8_max
+ scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN, 1))
+
+ # save the fp add output
+ fp_add_output = add_output.to(output1_ptr.type.element_ty)
+ fp_add_output = tl.reshape(fp_add_output, (BLOCK_M, BLOCK_N))
+
+ # pointers
+ output1_block_ptr = tl.make_block_ptr(
+ base=output1_ptr,
+ shape=(M, N),
+ strides=(output1_stride_0, output1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+    tl.store(output1_block_ptr, fp_add_output, boundary_check=(0, 1))
+
+ # Quantize
+ add_output = tl.fdiv(add_output, scale_output2)
+ scale_output2 = scale_output2.to(output2_scale_ptr.type.element_ty)
+ scale_output2 = tl.reshape(scale_output2, (BLOCK_M, BLOCK_SN))
+ add_output = tl.reshape(add_output, (BLOCK_M, BLOCK_N))
+
+ add_output = add_output.to(output2_ptr.type.element_ty)
+ add_output = tl.reshape(add_output, (BLOCK_M, BLOCK_N))
+
+ # pointers
+ output2_block_ptr = tl.make_block_ptr(
+ base=output2_ptr,
+ shape=(M, N),
+ strides=(output2_stride_0, output2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output2_ptr = tl.make_block_ptr(
+ base=output2_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output2_stride_0, s_output2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+ tl.store(output2_block_ptr, add_output, boundary_check=(0, 1))
+ tl.store(scale_output2_ptr, scale_output2, boundary_check=(0, 1))
+
+
+def fp8_add_Ifp_Ifp_Ofp_Og16(x1, x2, fp8type, QB): # suppose x1 is full precision or BF16
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x1.shape) == 3:
+ batched = True
+ BS = x1.shape[0]
+ x1 = x1.reshape(-1, x1.shape[-1])
+ x2 = x2.reshape(-1, x2.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x1.shape
+ SN = int(N / QB) # assume the shape of quantization block size is always 1 * G
+ assert x1.shape == x2.shape
+
+ y1 = torch.empty_like(x1, dtype=torch.bfloat16)
+ y2 = torch.empty_like(x2, dtype=fp8type)
+ s_y2 = torch.empty((M, SN), dtype=torch.bfloat16, device=x2.device)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_add_Ifp_Ifp_Ofp_Og16_kernel[grid](
+ y1,
+ y2,
+ s_y2,
+ x1,
+ x2,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x1.stride(0),
+ x1.stride(1),
+ x2.stride(0),
+ x2.stride(1),
+ y1.stride(0),
+ y1.stride(1),
+ y2.stride(0),
+ y2.stride(1),
+ s_y2.stride(0),
+ s_y2.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ # Recover 2D to 3D
+ if batched:
+ y1 = y1.reshape(BS, -1, y1.shape[-1])
+ y2 = y2.reshape(BS, -1, y2.shape[-1])
+ s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1])
+
+ return y1, (y2, s_y2)
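+
+
+# Illustrative sketch (assumptions: BF16 inputs, QB = 16). Note the argument order differs
+# from fp8_add_Ifp_Ifp_Ofp_Opt in add_bwd.py: here the FP8 type comes before the group size
+# and must already be a torch dtype, since it is used directly to allocate y2.
+#
+#   y1, (y2, s_y2) = fp8_add_Ifp_Ifp_Ofp_Og16(residual, backbone, torch.float8_e4m3fn, 16)
+#   # y1: BF16 sum;  y2: FP8 sum with 1 * 16 group scales s_y2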
diff --git a/llava/model/coat/activation/real_quantization/common.py b/llava/model/coat/activation/real_quantization/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..00de715cf3e80eddf74ea401e1bbc9811c3e4970
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/common.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+
+SCALE_MIN_THRES = 1e-10
+
+FP8_MAX_VALUE = {
+ torch.float8_e4m3fn: 448.0,
+ torch.float8_e5m2: 57344.0,
+}
+
+convert_str_to_fp8 = {"E4M3": torch.float8_e4m3fn, "E5M2": torch.float8_e5m2}
+convert_fp8_to_embit = {
+ torch.float8_e4m3fn: (4.0, 3.0),
+ torch.float8_e5m2: (5.0, 2.0),
+}
+
+
+def get_configs_io_block():
+ configs = []
+ for nstages in [3, 4, 5]:
+ for block_m in [32, 64]:
+ for block_n in [64, 128]:
+ for nwarps in [4, 8]:
+ configs.append(
+ triton.Config(
+ {"BLOCK_M": block_m, "BLOCK_N": block_n},
+ num_stages=nstages,
+ num_warps=nwarps,
+ )
+ )
+ return configs
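+
+
+# Quick reference (illustrative): the tables above map between string names, torch FP8
+# dtypes, and their dynamic ranges, e.g.
+#
+#   FP8_MAX_VALUE[convert_str_to_fp8["E4M3"]]   # 448.0
+#   FP8_MAX_VALUE[convert_str_to_fp8["E5M2"]]   # 57344.0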
+
+
+# from .common import SCALE_MIN_THRES, FP8_MAX_VALUE
+# SCALE_MIN_THRES: tl.constexpr,
+# + SCALE_MIN_THRES
+# SCALE_MIN_THRES=SCALE_MIN_THRES,
diff --git a/llava/model/coat/activation/real_quantization/fp8linear.py b/llava/model/coat/activation/real_quantization/fp8linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..32531fcb4bb63e4720e3fc48fb08b1f2d82d54db
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/fp8linear.py
@@ -0,0 +1,250 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import time
+from copy import deepcopy
+from dataclasses import dataclass
+
+import matplotlib.pyplot as plt
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd.function import Function
+
+from ..utils import quant_get_local_rank
+from ._division_transpose import fp8_division_transpose
+from ._quantize_pertensor_transpose import fp8_quantize_pertensor_transpose
+from .linear import fp8_linear_backward, fp8_linear_forward
+
+
+@dataclass
+class DefaultArgs:
+    fabit: str
+    fwbit: str
+    bobit: str
+
+
+class FP8Linear(nn.Linear):
+ def __init__(self, in_features, out_features, bias=True, device=None, args=None, layer_idx=0):
+ super().__init__(in_features, out_features, bias, device)
+
+ if args is None: # I do not want to pass a new argument to OLMo so just use this method
+ args = DefaultArgs(
+ fabit=os.environ["FABIT_FP8Linear"],
+ fwbit=os.environ["FWBIT_FP8Linear"],
+ bobit=os.environ["BOBIT_FP8Linear"],
+ )
+ self.args = deepcopy(args)
+
+ if quant_get_local_rank() == 0:
+ print(f"[qlinear debug] Apply QLinear, {layer_idx}")
+
+ self.layer_idx = layer_idx
+ self.layer_name = None
+
+ def forward(self, Input):
+ if self.training:
+ # if False:
+ output = QuantLinearTE.apply(Input, self.weight, self.bias, self.args, self.layer_name)
+ else:
+ output = F.linear(Input, self.weight, self.bias)
+
+ return output
+
+
+# if int(os.environ.get("LOCAL_RANK")) == 0:
+# import IPython
+# IPython.embed()
+# else:
+# import time
+# time.sleep(1000)
+
+# class QuantLinearTE(Function):
+# @staticmethod
+# def forward(ctx, input, weight, bias, args, layer_type):
+# ctx.saved = input, weight, bias, args, layer_type
+# return F.linear(input, weight, bias)
+
+# @staticmethod
+# def backward(ctx, grad_output):
+# input, weight, bias, args, layer_type = ctx.saved
+
+# C_in = input.shape[-1]
+# C_out = grad_output.shape[-1]
+
+# grad_output_flatten = grad_output.reshape(-1, C_out)
+# input_flatten = input.reshape(-1, C_in)
+
+# if grad_output_flatten.dtype == input_flatten.dtype:
+# grad_weight = grad_output_flatten.t().mm(input_flatten)
+# else:
+# grad_weight = grad_output_flatten.float().t().mm(input_flatten)
+
+# if grad_output_flatten.dtype == weight.dtype:
+# grad_input = grad_output_flatten.mm(weight)
+# else:
+# grad_input = grad_output_flatten.float().mm(weight)
+
+# if bias is not None:
+# grad_bias = grad_output_flatten.sum(0)
+# else:
+# grad_bias = None
+
+# grad_input_transform = grad_input.reshape(input.size())
+
+# return grad_input_transform, grad_weight, grad_bias, None, None
+
+
+class QuantLinearTE(Function):
+ @staticmethod
+ @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.bfloat16)
+ def forward(ctx, input, weight, bias, args, layer_name):
+
+ time_bench = os.getenv("TIME_BENCH")
+
+ if time_bench:
+ start_1 = torch.cuda.Event(enable_timing=True)
+ start_1.record()
+
+ # Qinput, Iscale, Qinput_t = fp8_division_transpose(input, 16, args.fabit)
+ Qinput, Iscale, Qinput_t = fp8_quantize_pertensor_transpose(input, 16, args.fabit, transpose_output_2d=True)
+
+ if time_bench:
+ end_1 = torch.cuda.Event(enable_timing=True)
+ end_1.record()
+ start_2 = torch.cuda.Event(enable_timing=True)
+ start_2.record()
+
+ # Qweight, Wscale, Qweight_t = fp8_division_transpose(weight, 16, args.fwbit)
+ Qweight, Wscale, Qweight_t = fp8_quantize_pertensor_transpose(weight, 16, args.fwbit, transpose_output_2d=True)
+
+ if time_bench:
+ end_2 = torch.cuda.Event(enable_timing=True)
+ end_2.record()
+ start_3 = torch.cuda.Event(enable_timing=True)
+ start_3.record()
+
+ ctx.saved = Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name
+ fc_output = fp8_linear_forward(Qinput, Iscale, Qweight, Wscale, False, 0, bias)
+
+ if time_bench:
+ end_3 = torch.cuda.Event(enable_timing=True)
+ end_3.record()
+ start_4 = torch.cuda.Event(enable_timing=True)
+ start_4.record()
+
+ output = F.linear(input, weight, bias)
+
+ end_4 = torch.cuda.Event(enable_timing=True)
+ end_4.record()
+
+ torch.cuda.synchronize()
+ if quant_get_local_rank() == 0:
+ print(
+ f"[Forward] Part 1: {start_1.elapsed_time(end_1):.6f} ms | Part 2: {start_2.elapsed_time(end_2):.6f} ms | Part 3: {start_3.elapsed_time(end_3):.6f} ms | "
+ f"FP8: {start_1.elapsed_time(end_3):.6f} | BF16: {start_4.elapsed_time(end_4):.6f} | Input shape: {input.shape} | Weight shape: {weight.shape}"
+ )
+
+ return fc_output
+
+ @staticmethod
+ @torch.amp.custom_bwd(device_type="cuda")
+ def backward(ctx, grad_output):
+ Qinput_t, Iscale, Qweight_t, Wscale, bias, args, layer_name = ctx.saved
+
+ time_bench = os.getenv("TIME_BENCH")
+ if time_bench:
+ start_1 = torch.cuda.Event(enable_timing=True)
+ start_1.record()
+
+ # Qgrad_output, Gscale, Qgrad_output_t = fp8_division_transpose(grad_output, 16, args.bobit, stochastic=False)
+ Qgrad_output, Gscale, Qgrad_output_t = fp8_quantize_pertensor_transpose(
+ grad_output, 16, args.bobit, stochastic=False, transpose_output_2d=True
+ )
+
+ if time_bench:
+ end_1 = torch.cuda.Event(enable_timing=True)
+ end_1.record()
+ start_2 = torch.cuda.Event(enable_timing=True)
+ start_2.record()
+
+ grad_input, grad_weight = fp8_linear_backward(
+ Qinput_t,
+ Iscale,
+ Qgrad_output,
+ Gscale,
+ Qgrad_output_t,
+ Qweight_t,
+ Wscale,
+ 16,
+ bias,
+ stochastic=False,
+ dgrad_quantize=False,
+ )
+
+ if time_bench:
+ end_2 = torch.cuda.Event(enable_timing=True)
+ end_2.record()
+ start_3 = torch.cuda.Event(enable_timing=True)
+ start_3.record()
+
+ if bias is not None:
+ grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
+ else:
+ grad_bias = None
+
+ if time_bench:
+ end_3 = torch.cuda.Event(enable_timing=True)
+ end_3.record()
+
+ # ========== BF16 ==========
+ C_in = Qinput_t.shape[0]
+ C_out = grad_output.shape[-1]
+ grad_output_flatten = grad_output.reshape(-1, C_out)
+ input_flatten = Qinput_t.t().reshape(-1, C_in).to(torch.bfloat16)
+ weight = Qweight_t.t().to(torch.bfloat16)
+
+ start_4 = torch.cuda.Event(enable_timing=True)
+ start_4.record()
+
+ if grad_output_flatten.dtype == input_flatten.dtype:
+ _grad_weight = grad_output_flatten.t().mm(input_flatten)
+ else:
+ _grad_weight = grad_output_flatten.float().t().mm(input_flatten)
+
+ if grad_output_flatten.dtype == weight.dtype:
+ _grad_input = grad_output_flatten.mm(weight)
+ else:
+ _grad_input = grad_output_flatten.float().mm(weight)
+
+ end_4 = torch.cuda.Event(enable_timing=True)
+ end_4.record()
+
+ torch.cuda.synchronize()
+ if quant_get_local_rank() == 0:
+ print(
+ f"[Backward] Part 1: {start_1.elapsed_time(end_1):.6f} ms | Part 2: {start_2.elapsed_time(end_2):.6f} ms | Part 3: {start_3.elapsed_time(end_3):.6f} ms | "
+ f"FP8: {start_1.elapsed_time(end_3):.6f} | BF16: {start_4.elapsed_time(end_4):.6f} | Input shape: {Qinput_t.shape} | Weight shape: {weight.shape}"
+ )
+
+ return grad_input, grad_weight, grad_bias, None, None
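+
+
+# Minimal usage sketch (assumptions: CUDA + Triton available; FP8 formats are given as the
+# string names used elsewhere in this package, e.g. "E4M3"/"E5M2". The environment-variable
+# fallback is only consulted when no args object is passed to the constructor.)
+#
+#   os.environ["FABIT_FP8Linear"] = "E4M3"
+#   os.environ["FWBIT_FP8Linear"] = "E4M3"
+#   os.environ["BOBIT_FP8Linear"] = "E5M2"
+#   layer = FP8Linear(4096, 4096, bias=True, device="cuda").to(torch.bfloat16)
+#   layer.train()   # the quantized path is only taken in training mode
+#   out = layer(torch.randn(2, 16, 4096, device="cuda", dtype=torch.bfloat16))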
diff --git a/llava/model/coat/activation/real_quantization/func_layernorm_noparam.py b/llava/model/coat/activation/real_quantization/func_layernorm_noparam.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3a1f57f3d1db8b19da687fe823a83f35d97d7d6
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/func_layernorm_noparam.py
@@ -0,0 +1,303 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Layer Normalization
+====================
+In this tutorial, you will write a high-performance layer normalization
+kernel that runs faster than the PyTorch implementation.
+
+In doing so, you will learn about:
+
+* Implementing backward pass in Triton.
+
+* Implementing parallel reduction in Triton.
+
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+from ._division import _stochastic_rounding
+from ._division_transpose import fp8_division_transpose
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit
+
+"""FP8 LayerNorm. Forward + Backward"""
+"""Forward: input uses 1 * 16 quantization"""
+"""Forward: output use per-tensor quantization."""
+"""Backward: input uses full-precision/BF16."""
+"""Backward: output uses full-precision/BF16"""
+"""Support 3D input, but need to first flatten to 2D to perform calculation."""
+
+
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["N"] // args["QB"],
+ "BLOCK_SN2": lambda args: args["N2"] // args["QB"],
+ }
+)
+@triton.jit
+def _layer_norm_fwd_fused(
+ X, # pointer to the input
+ SX, # pointer to the scale of input
+ Y, # pointer to the output
+ SY, # pointer to the scale of output
+ Mean, # pointer to the mean
+ Rstd, # pointer to the 1/std
+ stride, # how much to increase the pointer when moving by 1 row
+ scale_stride, # how much to increase the pointer when moving by 1 row
+ N: tl.constexpr, # number of columns in X,
+ N2: tl.constexpr, # number of columns in X,
+ SN: tl.constexpr,
+ SN2: tl.constexpr,
+ QB: tl.constexpr,
+ eps, # epsilon to avoid division by zero
+ fp8_max,
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+ BLOCK_SN2: tl.constexpr,
+):
+ # Map the program id to the row of X and Y it should compute.
+ row = tl.program_id(0)
+ Y += row * stride
+ X += row * stride
+ SX += row * scale_stride
+
+ mean = 0
+ cols = tl.arange(0, N2)
+ scale_cols = tl.arange(0, SN2)
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+ scale_x = tl.load(SX + scale_cols, mask=scale_cols < SN, other=0.0).to(tl.float32)
+
+    # Dequantize the 1 * 16 group-quantized input
+ scale_x = tl.reshape(scale_x, (BLOCK_SN2, 1))
+ x = tl.reshape(x, (BLOCK_SN2, QB))
+ x = x * scale_x
+ x = tl.reshape(x, (N2,))
+
+ # Compute mean and Variance
+ mean = tl.sum(x, axis=0) / N
+ # Compute variance
+ _var = (x - mean) * (x - mean)
+ var = tl.sum(_var, axis=0) / N
+ rstd = 1 / tl.sqrt(var + eps)
+ # Write mean / rstd
+ tl.store(Mean + row, mean)
+ tl.store(Rstd + row, rstd)
+    # Normalize (this parameter-free variant applies no affine transform)
+ x_hat = (x - mean) * rstd
+
+ # Scale calculation
+ abs_x_hat = tl.abs(x_hat)
+ max_val = tl.max(abs_x_hat, axis=0) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = scale_output.to(SY.type.element_ty)
+
+ tl.store(SY + row, scale_output)
+
+ # Write output
+ tl.store(Y + cols, x_hat, mask=cols < N)
+
+
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["N"] // args["QB"],
+ "BLOCK_SN2": lambda args: args["N2"] // args["QB"],
+ }
+)
+@triton.jit
+def _layer_norm_bwd_dx_fused(
+ DX, # pointer to the input gradient
+ DY, # pointer to the output gradient
+ X, # pointer to the input
+ SX, # pointer to the input
+ noise_ptr, # noise for stochastic
+ Mean, # pointer to the mean
+ Rstd, # pointer to the 1/std
+ stride, # how much to increase the pointer when moving by 1 row
+ scale_stride, # how much to increase the pointer when moving by 1 row
+ N: tl.constexpr, # number of columns in X
+ N2: tl.constexpr, # number of columns in X
+ SN: tl.constexpr,
+ SN2: tl.constexpr,
+ QB: tl.constexpr,
+ SCALE_MIN_THRES: tl.constexpr,
+ STOCHASTIC: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+ BLOCK_SN2: tl.constexpr,
+):
+ # Map the program id to the elements of X, DX, and DY it should compute.
+ row = tl.program_id(0)
+ cols = tl.arange(0, N2)
+ scale_cols = tl.arange(0, SN2)
+ mask = cols < N
+ X += row * stride
+ DY += row * stride
+ DX += row * stride
+ SX += row * scale_stride
+ # Load data to SRAM
+ x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
+ scale_x = tl.load(SX + scale_cols, mask=scale_cols < SN, other=0.0).to(tl.float32)
+ scale_x = tl.reshape(scale_x, (BLOCK_SN2, 1))
+ x = tl.reshape(x, (BLOCK_SN2, QB))
+ x = x * scale_x
+ x = tl.reshape(x, N2)
+
+ dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32)
+
+ mean = tl.load(Mean + row)
+ rstd = tl.load(Rstd + row)
+ # Compute dx
+ xhat = (x - mean) * rstd
+ xhat = tl.where(mask, xhat, 0.0)
+ dy = tl.where(mask, dy, 0.0)
+ c1 = tl.sum(xhat * dy, axis=0) / N
+ c2 = tl.sum(dy, axis=0) / N
+ dx = (dy - (xhat * c1 + c2)) * rstd
+
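+    # NOTE (assumption): the stochastic branch below expects `e_bit`/`m_bit` constexprs, as in
+    # `_division._stochastic_rounding`; they are not passed into this kernel, so calling the
+    # Python wrapper with stochastic=True will not compile as written.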
+ if STOCHASTIC:
+ # noise_ptr += row * stride
+ # noise_block_ptr = noise_ptr + cols
+ # noise = tl.load(noise_block_ptr, mask=mask, other=0.)
+
+ noise_offset = row * stride + cols
+ noise = tl.rand(0, noise_offset)
+
+ dx = _stochastic_rounding(dx, noise, e_bit, m_bit)
+
+ dx = dx.to(DX.type.element_ty)
+
+ # Write dx
+ tl.store(DX + cols, dx, mask=mask)
+
+
+def fp8_layernorm_noparam_forward(x, s_x, QB, eps, transpose_output_2d=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+
+ # allocate output
+ M, N = x.shape
+ _, SN = s_x.shape
+ y = torch.empty_like(x, dtype=torch.float32)
+ s_y = torch.empty(
+ (M,), dtype=torch.bfloat16, device=x.device
+ ) # We do this because we apply per-tensor quantization for it afterwards
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device)
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
+ # heuristics for number of warps
+ num_warps = 8
+ fp8MaxValue = FP8_MAX_VALUE[x.dtype]
+
+ N2 = triton.next_power_of_2(N)
+ SN2 = N2 // QB
+ # enqueue kernel
+ _layer_norm_fwd_fused[(M,)]( #
+ x,
+ s_x,
+ y,
+ s_y,
+ mean,
+ rstd, #
+ x.stride(0),
+ s_x.stride(0),
+ N,
+ N2,
+ SN,
+ SN2,
+ QB,
+ eps,
+ fp8MaxValue,
+ SCALE_MIN_THRES=SCALE_MIN_THRES, #
+ num_warps=num_warps,
+ num_ctas=1,
+ )
+ # reduction
+ s_y_max = s_y.max()
+ qy, s_y_max, qy_t = fp8_division_transpose(y, QB, x.dtype, s_y_max)
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ qy = qy.reshape(BS, -1, y.shape[-1])
+ if not transpose_output_2d:
+ qy_t = qy_t.reshape(BS, -1, y.shape[-1])
+
+ return qy, s_y_max, qy_t, (mean, rstd, num_warps)
+
+
+def fp8_layernorm_noparam_backward(x, s_x, g, QB, m, v, num_warps, stochastic=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+ g = g.reshape(-1, g.shape[-1])
+
+ if stochastic:
+ # noise = torch.empty_like(g, dtype=torch.float32).uniform_(-0.5, 0.5)
+ noise = None
+ else:
+ noise = None
+
+ # heuristics for amount of parallel reduction stream for DW/DB
+ dx = torch.empty_like(g, dtype=torch.bfloat16)
+ # enqueue kernel using forward pass heuristics
+ # also compute partial sums for DW and DB
+ M, N = g.shape
+ _, SN = s_x.shape
+
+ N2 = triton.next_power_of_2(N)
+ SN2 = triton.next_power_of_2(SN)
+ _layer_norm_bwd_dx_fused[(M,)]( #
+ dx,
+ g,
+ x,
+ s_x,
+ noise,
+ m,
+ v, #
+ x.stride(0),
+ s_x.stride(0),
+ N,
+ N2,
+ SN,
+ SN2,
+ QB,
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ STOCHASTIC=stochastic,
+ num_warps=num_warps,
+ )
+
+ if batched:
+ dx = dx.reshape(BS, -1, dx.shape[-1])
+
+ return dx
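+
+
+# Shape/contract sketch (illustrative; assumes QB = 16 and an FP8 input qx with one scale
+# per 16-element group in sx):
+#
+#   qy, s_y_max, qy_t, (mean, rstd, num_warps) = fp8_layernorm_noparam_forward(qx, sx, 16, 1e-5)
+#   # qy: FP8 output with per-tensor scale s_y_max; qy_t: its transposed FP8 copy;
+#   # mean/rstd are cached for the backward pass:
+#   dx = fp8_layernorm_noparam_backward(qx, sx, grad_bf16, 16, mean, rstd, num_warps)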
diff --git a/llava/model/coat/activation/real_quantization/func_quantize.py b/llava/model/coat/activation/real_quantization/func_quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..f63942146890a02132b431a10df39f22a65ebec4
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/func_quantize.py
@@ -0,0 +1,102 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+
+from ._quantize import fp8_quantize
+from ._quantize_pertensor import fp8_quantize_pertensor
+
+
+class Coat_quantize_bgn(nn.Module):
+ def __init__(self, args=None, layer_type=""):
+ super().__init__()
+ self.args = deepcopy(args)
+ self.fp8type = self.args.fabit
+ self.layer_type = layer_type
+
+ def forward(self, input):
+ if self.training:
+ return Coat_quantize_bgn_func.apply(input, self.args.group_size, self.fp8type)
+ else:
+ return input, None, None
+
+
+class Coat_quantize_bgn_func(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, group_size, fp8type):
+ """
+ (Qoutput, Oscale) uses 1 * 16 quantization
+ """
+ Qoutput, Oscale = fp8_quantize(input, group_size, fp8type)
+ # For autograd
+ Qoutput = Qoutput.view(torch.float8_e4m3fn)
+ ctx.saved = group_size
+ return input, Qoutput, Oscale
+
+ @staticmethod
+ def backward(ctx, grad_output, Qgrad_output, Gscale):
+ """
+ (Qgrad_output, Gscale) uses 1 * 16 quantization
+ """
+ return grad_output, None, None
+
+
+class Coat_quantize_end(nn.Module):
+ def __init__(self, args=None, layer_type=""):
+ super().__init__()
+ self.args = deepcopy(args)
+ self.fp8type = self.args.babit
+ self.layer_type = layer_type
+
+ def forward(self, input, Qinput, Iscale):
+ if self.training:
+ return Coat_quantize_end_func.apply(input, Qinput, Iscale, self.args.group_size, self.fp8type)
+ else:
+ return input
+
+
+class Coat_quantize_end_func(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, Qinput, Iscale, group_size, fp8type):
+ """
+ (Qinput, Iscale) uses 1 * 16 quantization
+ """
+ ctx.saved = group_size, fp8type
+
+ return input
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ """
+ (Qgrad_output, Gscale) uses per-tensor quantization
+ """
+
+ group_size, fp8type = ctx.saved
+ Qgrad_output, Gscale, Gscale_g16 = fp8_quantize_pertensor(grad_output, group_size, fp8type, stochastic=False)
+
+ # For autograd
+ Qgrad_output = Qgrad_output.view(torch.float8_e4m3fn)
+
+ return grad_output, Qgrad_output, Gscale_g16, None, None
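+
+
+# Usage sketch (illustrative): these two modules bracket a quantized block. Coat_quantize_bgn
+# quantizes the activation once at the block entry; Coat_quantize_end re-quantizes the incoming
+# gradient at the block exit during backward. It assumes an args object exposing fabit, babit
+# and group_size (e.g. "E4M3", "E4M3", 16).
+#
+#   bgn = Coat_quantize_bgn(args)
+#   end = Coat_quantize_end(args)
+#   x, qx, sx = bgn(hidden_states)   # in eval mode returns (hidden_states, None, None)
+#   ...                              # FP8-aware ops consume (x, qx, sx)
+#   hidden_states = end(x, qx, sx)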
diff --git a/llava/model/coat/activation/real_quantization/func_rmsnorm.py b/llava/model/coat/activation/real_quantization/func_rmsnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7a8d38e2909d7a45a7792e9b1502af5aa7dee95
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/func_rmsnorm.py
@@ -0,0 +1,345 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Layer Normalization
+====================
+In this tutorial, you will write a high-performance layer normalization
+kernel that runs faster than the PyTorch implementation.
+
+In doing so, you will learn about:
+
+* Implementing backward pass in Triton.
+
+* Implementing parallel reduction in Triton.
+
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+from ._division_transpose import fp8_division_transpose
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES
+
+"""RMSNorm Forward + Backward"""
+"""Forward: Input uses 1 * 16 group quantization"""
+"""Forward: Output uses per-tensor quantization"""
+"""Backward: Input uses full-precision/BF16."""
+"""Backward: Output uses full-precision/BF16."""
+
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["N"] // args["QB"],
+ "BLOCK_SN2": lambda args: args["N2"] // args["QB"],
+ }
+)
+@triton.jit
+def _rms_norm_fwd_fused(
+ X, # pointer to the input
+ SX, # pointer to the scale of input
+ Y, # pointer to the output
+ SY, # pointer to the scale of output
+ W, # Weight
+ Rstd, # pointer to the 1/std
+ stride, # how much to increase the pointer when moving by 1 row
+ scale_stride, # how much to increase the pointer when moving by 1 row
+ N: tl.constexpr, # number of columns in X,
+ N2: tl.constexpr, # number of columns in X,
+ SN: tl.constexpr,
+ SN2: tl.constexpr,
+ QB: tl.constexpr,
+ eps, # epsilon to avoid division by zero
+ fp8_max,
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+ BLOCK_SN2: tl.constexpr,
+):
+ # Map the program id to the row of X and Y it should compute.
+ row = tl.program_id(0)
+ Y += row * stride
+ X += row * stride
+ SX += row * scale_stride
+
+ cols = tl.arange(0, N2)
+ scale_cols = tl.arange(0, SN2)
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+ scale_x = tl.load(SX + scale_cols, mask=scale_cols < SN, other=0.0).to(tl.float32)
+
+    # Dequantize the 1 * 16 group-quantized input
+ scale_x = tl.reshape(scale_x, (BLOCK_SN2, 1))
+ x = tl.reshape(x, (BLOCK_SN2, QB))
+ x = x * scale_x
+ x = tl.reshape(x, N2)
+
+ # Compute variance
+ _var = x * x
+ var = tl.sum(_var, axis=0) / N
+ rstd = 1 / tl.sqrt(var + eps)
+
+ # Write mean / rstd
+ tl.store(Rstd + row, rstd)
+
+ # Normalize and apply linear transformation
+ w = tl.load(W + cols, mask=cols < N, other=0.0)
+ x_hat = x * rstd
+ y = x_hat * w
+
+ # Scale calculation
+ abs_y = tl.abs(y)
+ max_val = tl.max(abs_y) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ tl.store(SY + row, scale_output)
+
+ # Write output
+ tl.store(Y + cols, y, mask=cols < N)
+
+
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["N"] // args["QB"],
+ "BLOCK_SN2": lambda args: args["N2"] // args["QB"],
+ }
+)
+@triton.jit
+def _rms_norm_bwd_dx_fused(
+ DX, # pointer to the input gradient
+ DY, # pointer to the output gradient
+ DW,
+ X, # pointer to the input
+ SX, # pointer to the input
+ W, # weight
+ Rstd, # pointer to the 1/std
+ Lock,
+ stride, # how much to increase the pointer when moving by 1 row
+ scale_stride, # how much to increase the pointer when moving by 1 row
+ N: tl.constexpr, # number of columns in X,
+ N2: tl.constexpr, # number of columns in X,
+ SN: tl.constexpr,
+ SN2: tl.constexpr,
+ QB: tl.constexpr,
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+ BLOCK_SN2: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr,
+):
+ # Map the program id to the elements of X, DX, and DY it should compute.
+ row = tl.program_id(0)
+ cols = tl.arange(0, N2)
+ scale_cols = tl.arange(0, SN2)
+ mask = cols < N
+ X += row * stride
+ DY += row * stride
+ DX += row * stride
+ SX += row * scale_stride
+
+ # Offset locks and weights/biases gradient pointer for parallel reduction
+ lock_id = row % GROUP_SIZE_M
+ Lock += lock_id
+ Count = Lock + GROUP_SIZE_M
+ DW = DW + lock_id * N + cols
+
+ # Load data to SRAM
+ x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
+ scale_x = tl.load(SX + scale_cols, mask=scale_cols < SN, other=0.0).to(tl.float32)
+ scale_x = tl.reshape(scale_x, (BLOCK_SN2, 1))
+ x = tl.reshape(x, (BLOCK_SN2, QB))
+ x = x * scale_x
+ x = tl.reshape(x, N2)
+ dy = tl.load(DY + cols, mask=mask, other=0.0).to(tl.float32)
+
+ # Load weight
+ w = tl.load(W + cols, mask=cols < N, other=0.0).to(tl.float32)
+ rstd = tl.load(Rstd + row).to(tl.float32)
+
+ # Compute dx
+ xhat = x * rstd
+ wdy = w * dy
+ xhat = tl.where(mask, xhat, 0.0)
+ wdy = tl.where(mask, wdy, 0.0)
+ c1 = tl.sum(xhat * wdy, axis=0) / N
+    dx = (wdy - (xhat * c1)) * rstd  # layer norm has a c2 term; rmsnorm does not
+
+ dx = dx.to(DX.type.element_ty)
+
+ # Write dx
+ tl.store(DX + cols, dx, mask=mask)
+
+ # Accumulate partial sums for dw/db
+ partial_dw = (dy * xhat).to(w.dtype)
+ while tl.atomic_cas(Lock, 0, 1) == 1:
+ pass
+ count = tl.load(Count)
+ # First store doesn't accumulate
+ if count == 0:
+ tl.atomic_xchg(Count, 1)
+ else:
+ partial_dw += tl.load(DW, mask=mask)
+ tl.store(DW, partial_dw, mask=mask)
+ # Release the lock
+ tl.atomic_xchg(Lock, 0)
+
+
+@triton.jit
+def _rms_norm_bwd_dwdb(
+ DW, # pointer to the partial sum of weights gradient
+ FINAL_DW, # pointer to the weights gradient
+ M, # GROUP_SIZE_M
+ N, # number of columns
+ BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+):
+ # Map the program id to the elements of DW and DB it should compute.
+ pid = tl.program_id(0)
+ cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+ # Iterate through the rows of DW and DB to sum the partial sums.
+ for i in range(0, M, BLOCK_SIZE_M):
+ rows = i + tl.arange(0, BLOCK_SIZE_M)
+ mask = (rows[:, None] < M) & (cols[None, :] < N)
+ offs = rows[:, None] * N + cols[None, :]
+ dw += tl.load(DW + offs, mask=mask, other=0.0)
+ # Write the final sum to the output.
+ sum_dw = tl.sum(dw, axis=0)
+ tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
+
+
+def fp8_rmsnorm_forward(x, s_x, w, QB, eps, transpose_output_2d=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+
+ # allocate output
+ M, N = x.shape
+ _, SN = s_x.shape
+ y = torch.empty_like(x, dtype=torch.bfloat16)
+ s_y = torch.empty((M,), dtype=torch.bfloat16, device=x.device)
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
+ # heuristics for number of warps
+ num_warps = 8
+ fp8MaxValue = FP8_MAX_VALUE[x.dtype]
+
+ N2 = triton.next_power_of_2(N)
+ SN2 = N2 // QB
+
+ # import os
+ # if int(os.environ.get("LOCAL_RANK")) == 7:
+ # print(x.device, x.shape, x.dtype, s_x.shape, s_x.dtype, y.shape, y.dtype, s_y.shape, s_y.dtype, w.shape, w.dtype, rstd.shape, rstd.dtype, x.stride(0), s_x.stride(0), N, N2, SN, SN2, QB, "\n")
+ # enqueue kernel
+ _rms_norm_fwd_fused[(M,)]( #
+ x,
+ s_x,
+ y,
+ s_y,
+ w,
+ rstd, #
+ x.stride(0),
+ s_x.stride(0),
+ N,
+ N2,
+ SN,
+ SN2,
+ QB,
+ eps,
+ fp8MaxValue,
+ SCALE_MIN_THRES=SCALE_MIN_THRES, #
+ num_warps=num_warps,
+ num_ctas=1,
+ )
+
+ # reduction
+ s_y_max = s_y.max()
+ qy, s_y_max, qy_t = fp8_division_transpose(y, QB, x.dtype, s_y_max)
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ qy = qy.reshape(BS, -1, y.shape[-1])
+ if not transpose_output_2d:
+ qy_t = qy_t.reshape(BS, -1, y.shape[-1])
+
+ return qy, s_y_max, qy_t, (w.clone(), rstd, num_warps)
+
+
+def fp8_rmsnorm_backward(x, s_x, g, w, v, QB, num_warps):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+ g = g.reshape(-1, g.shape[-1])
+
+ # enqueue kernel using forward pass heuristics
+ # also compute partial sums for DW and DB
+ M, N = g.shape
+ _, SN = s_x.shape
+
+ GROUP_SIZE_M = 128
+ # heuristics for amount of parallel reduction stream for DW/DB
+ locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=w.device)
+ _dw = torch.zeros((GROUP_SIZE_M, N), dtype=w.dtype, device=w.device)
+ dw = torch.empty((N,), dtype=w.dtype, device=w.device)
+
+ dx = torch.empty_like(g, dtype=torch.bfloat16)
+
+ N2 = triton.next_power_of_2(N)
+ SN2 = triton.next_power_of_2(SN)
+ _rms_norm_bwd_dx_fused[(M,)]( #
+ dx,
+ g,
+ _dw,
+ x,
+ s_x,
+ w,
+ v,
+ locks, #
+ x.stride(0),
+ s_x.stride(0),
+ N,
+ N2,
+ SN,
+ SN2,
+ QB,
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ num_warps=num_warps,
+ GROUP_SIZE_M=GROUP_SIZE_M,
+ )
+
+ if batched:
+ dx = dx.reshape(BS, -1, dx.shape[-1])
+
+ grid = lambda meta: [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
+ # accumulate partial sums in separate kernel
+ _rms_norm_bwd_dwdb[grid](_dw, dw, min(GROUP_SIZE_M, M), N, BLOCK_SIZE_M=32, BLOCK_SIZE_N=128, num_ctas=1) # #
+
+ return dx, dw
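+
+
+# Forward/backward contract sketch (illustrative; QB = 16, FP8 input qx with 1 * 16 group
+# scales sx, BF16 weight w):
+#
+#   qy, s_y_max, qy_t, (w_saved, rstd, num_warps) = fp8_rmsnorm_forward(qx, sx, w, 16, 1e-6)
+#   dx, dw = fp8_rmsnorm_backward(qx, sx, grad_bf16, w_saved, rstd, 16, num_warps)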
diff --git a/llava/model/coat/activation/real_quantization/gelu_bwd.py b/llava/model/coat/activation/real_quantization/gelu_bwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..8772707f7f5e6ef9b68b9c29f2b1b329f2c3177d
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/gelu_bwd.py
@@ -0,0 +1,235 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division import fp8_division
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block
+
+"""GELU Activation Backward"""
+"""Input uses 1 * 16 group quantization"""
+"""Grad uses full-precision/BF16"""
+"""Output uses per-tensor quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_gelu_backward_kernel(
+ output_ptr,
+ output_scale_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ grad_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_input_stride_0,
+ s_input_stride_1, # scale of input stride
+ grad_stride_0,
+ grad_stride_1, # input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input_ptr = tl.make_block_ptr(
+ base=input_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input_stride_0, s_input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ scale_input = tl.load(scale_input_ptr)
+
+ input = input.to(tl.float32)
+ scale_input = scale_input.to(tl.float32)
+
+ # Dequantize and gelu calculation
+ scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1))
+ input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+ input = input * scale_input
+
+ # pointers of gradient
+ grad_block_ptr = tl.make_block_ptr(
+ base=grad_ptr,
+ shape=(M, N),
+ strides=(grad_stride_0, grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ grad = tl.load(grad_block_ptr)
+ grad = grad.to(tl.float32)
+ grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB))
+
+ # Actual Calculation of GELU's backward
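+    # d/dx [x * Phi(x)] = Phi(x) + x * phi(x), where Phi is the standard normal CDF and
+    # phi(x) = exp(-x^2 / 2) / sqrt(2 * pi) is its PDF (exact, erf-based GELU)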
+ pi = float(torch.pi)
+ cdf = (1.0 + tl.math.erf(input / tl.math.sqrt(2.0))) / 2
+ exp = input * tl.exp(-libdevice.pow(input, 2) / 2) / tl.sqrt(2 * pi)
+ gelu_output = cdf + exp
+
+ gelu_output = gelu_output * grad
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(gelu_output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ # Quantize
+ # gelu_output = tl.fdiv(gelu_output, scale_output)
+ gelu_output = gelu_output.to(output_ptr.type.element_ty)
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+ gelu_output = tl.reshape(gelu_output, (BLOCK_M, BLOCK_N))
+
+ # debug
+ # gelu_output = input
+ # scale_output = scale_input
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, gelu_output, boundary_check=(0, 1))
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_gelu_backward(x, s_x, g, QB, fp8type):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+ g = g.reshape(-1, g.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G
+
+ if isinstance(fp8type, str):
+ fp8type = convert_str_to_fp8[fp8type]
+ y = torch.empty_like(g, dtype=torch.bfloat16)
+ s_y = torch.empty_like(s_x, dtype=torch.bfloat16)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_gelu_backward_kernel[grid](
+ y,
+ s_y,
+ x,
+ s_x,
+ g,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ s_x.stride(0),
+ s_x.stride(1),
+ g.stride(0),
+ g.stride(1),
+ y.stride(0),
+ y.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ # Per-tensor quantization
+ s_y_max = s_y.max()
+ qy, s_y_max = fp8_division(y, QB, fp8type, s_y_max)
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ qy = qy.reshape(BS, -1, qy.shape[-1])
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+
+ return qy, s_y_max
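+
+
+# Illustrative call (assumes an FP8 activation qx with 1 * 16 group scales sx, a BF16
+# upstream gradient g, QB = 16, and the FP8 format given as a string or torch dtype):
+#
+#   q_dx, s_dx_max = fp8_gelu_backward(qx, sx, g, 16, "E5M2")   # per-tensor-quantized grad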
diff --git a/llava/model/coat/activation/real_quantization/gelu_bwd_legacy.py b/llava/model/coat/activation/real_quantization/gelu_bwd_legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f2b36bcb8c51d33392cd740c3dcbf6de4c9f082
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/gelu_bwd_legacy.py
@@ -0,0 +1,257 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division import fp8_division
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block
+
+"""GELU Activation Backward"""
+"""Input uses 1 * 16 group quantization"""
+"""Grad uses 1 * 16 group quantization"""
+"""Output uses per-tensor quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_gelu_backward_legacy_kernel(
+ output_ptr,
+ output_scale_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ grad_ptr,
+ grad_scale_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_input_stride_0,
+ s_input_stride_1, # scale of input stride
+ grad_stride_0,
+ grad_stride_1, # input stride
+ s_grad_stride_0,
+ s_grad_stride_1, # scale of input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input_ptr = tl.make_block_ptr(
+ base=input_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input_stride_0, s_input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ scale_input = tl.load(scale_input_ptr)
+
+ input = input.to(tl.float32)
+ scale_input = scale_input.to(tl.float32)
+
+ # Dequantize and gelu calculation
+ scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1))
+ input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+ input = input * scale_input
+
+ # pointers of gradient
+ grad_block_ptr = tl.make_block_ptr(
+ base=grad_ptr,
+ shape=(M, N),
+ strides=(grad_stride_0, grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # grad ptr
+ scale_grad_ptr = tl.make_block_ptr(
+ base=grad_scale_ptr,
+ shape=(M, SN),
+ strides=(s_grad_stride_0, s_grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ grad = tl.load(grad_block_ptr)
+ scale_grad = tl.load(scale_grad_ptr)
+
+ grad = grad.to(tl.float32)
+ scale_grad = scale_grad.to(tl.float32)
+
+ # Dequantize and gelu calculation
+ scale_grad = tl.reshape(scale_grad, (BLOCK_M, BLOCK_SN, 1))
+ grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB))
+ grad = grad * scale_grad
+
+ # Actual Calculation of GELU's backward
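+    # Same derivative as in gelu_bwd.py: d/dx [x * Phi(x)] = Phi(x) + x * phi(x)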
+ pi = float(torch.pi)
+ cdf = (1.0 + tl.math.erf(input / tl.math.sqrt(2.0))) / 2
+ exp = input * tl.exp(-libdevice.pow(input, 2) / 2) / tl.sqrt(2 * pi)
+ gelu_output = cdf + exp
+
+ gelu_output = gelu_output * grad
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(gelu_output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ # Quantize
+ # gelu_output = tl.fdiv(gelu_output, scale_output)
+ gelu_output = gelu_output.to(output_ptr.type.element_ty)
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+ gelu_output = tl.reshape(gelu_output, (BLOCK_M, BLOCK_N))
+
+ # debug
+ # gelu_output = input
+ # scale_output = scale_input
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, gelu_output, boundary_check=(0, 1))
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_gelu_backward_legacy(x, s_x, g, s_g, QB):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+ g = g.reshape(-1, g.shape[-1])
+ s_g = s_g.reshape(-1, s_g.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G
+
+ y = torch.empty_like(g, dtype=torch.bfloat16)
+ s_y = torch.empty_like(s_g, dtype=s_g.dtype)
+ fp8MaxValue = FP8_MAX_VALUE[g.dtype] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_gelu_backward_legacy_kernel[grid](
+ y,
+ s_y,
+ x,
+ s_x,
+ g,
+ s_g,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ s_x.stride(0),
+ s_x.stride(1),
+ g.stride(0),
+ g.stride(1),
+ s_g.stride(0),
+ s_g.stride(1),
+ y.stride(0),
+ y.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ # Per-tensor quantization
+ s_y_max = s_y.max()
+ qy, s_y_max = fp8_division(y, QB, g.dtype, s_y_max)
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ qy = qy.reshape(BS, -1, qy.shape[-1])
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+
+ return qy, s_y_max
diff --git a/llava/model/coat/activation/real_quantization/gelu_fwd.py b/llava/model/coat/activation/real_quantization/gelu_fwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..547c6f8339feef79870ade12a407c852631ae136
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/gelu_fwd.py
@@ -0,0 +1,209 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division_transpose import fp8_division_transpose
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block
+
+"""GELU Activation Forward"""
+"""Input uses 1 * 16 group quantization"""
+"""Output uses 1 * 16 group quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+__all__ = ["fp8_gelu_forward"]
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_gelu_forward_kernel(
+ output_ptr,
+ output_scale_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_input_stride_0,
+ s_input_stride_1, # scale of input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input_ptr = tl.make_block_ptr(
+ base=input_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input_stride_0, s_input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ scale_input = tl.load(scale_input_ptr)
+
+ input = input.to(tl.float32)
+ scale_input = scale_input.to(tl.float32)
+
+ # Dequantize and gelu calculation
+ scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1))
+ input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+ input = input * scale_input
+
+ # Actual Calculation of GeLU
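+    # Exact (erf-based) form: GELU(x) = x * Phi(x), where Phi is the standard
+    # normal CDF; this kernel does not use the tanh approximation.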
+ cdf = (1.0 + tl.math.erf(input / tl.math.sqrt(2.0))) / 2
+ gelu_output = cdf * input
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(gelu_output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ # Quantize
+ # gelu_output = tl.fdiv(gelu_output, scale_output)
+ gelu_output = gelu_output.to(output_ptr.type.element_ty)
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+ gelu_output = tl.reshape(gelu_output, (BLOCK_M, BLOCK_N))
+
+ # debug
+ # gelu_output = input
+ # scale_output = scale_input
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, gelu_output, boundary_check=(0, 1))
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_gelu_forward(x, s_x, QB, transpose_output_2d=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G
+
+ y = torch.empty_like(x, dtype=torch.bfloat16)
+ s_y = torch.empty_like(s_x, dtype=s_x.dtype)
+ fp8MaxValue = FP8_MAX_VALUE[x.dtype] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_gelu_forward_kernel[grid](
+ y,
+ s_y,
+ x,
+ s_x,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ s_x.stride(0),
+ s_x.stride(1),
+ y.stride(0),
+ y.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ s_y_max = s_y.max()
+ qy, s_y_max, qy_t = fp8_division_transpose(y, QB, x.dtype, s_y_max)
+
+ # Recover 2D to 3D
+ if batched:
+ qy = qy.reshape(BS, -1, qy.shape[-1])
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+ if not transpose_output_2d:
+ qy_t = qy_t.reshape(BS, -1, qy.shape[-1])
+
+ return qy, s_y_max, qy_t
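+
+
+# Illustrative usage sketch (a comment only; tensor names and shapes are
+# assumptions): given a 1 * 16 group-quantized FP8 activation `x` with
+# per-group scales `s_x`, the wrapper returns the GELU output re-quantized
+# per-tensor, its scale, and a transposed copy for the subsequent matmul.
+#
+#   qy, s_y_max, qy_t = fp8_gelu_forward(x, s_x, QB=16)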
diff --git a/llava/model/coat/activation/real_quantization/linear.py b/llava/model/coat/activation/real_quantization/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..d643f18f51432efc3a2f80c55b73a18e99f0eb75
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/linear.py
@@ -0,0 +1,307 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+try:
+ from ._division import _stochastic_rounding
+ from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8
+except ImportError:
+    from common import SCALE_MIN_THRES, FP8_MAX_VALUE, convert_str_to_fp8, convert_fp8_to_embit
+    from COAT.coat.activation.real_quantization._division import _stochastic_rounding
+
+import os
+import time
+
+"""Linear Layer Forward + Backward"""
+"""Input uses per-tensor quantization"""
+"""Output is full-precision/BF16 (for FlashAttention) or 1 * 16 quantization (for the rest)"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+def get_configs_io_block():
+ configs = []
+ for nstages in [3]:
+ for block_m in [128, 256]:
+ for block_n in [128, 256]:
+ for block_k in [128, 256]:
+ for nwarps in [8]:
+ configs.append(
+ triton.Config(
+ {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
+ num_stages=nstages,
+ num_warps=nwarps,
+ )
+ )
+ return configs
+
+
+# @triton.autotune(
+# configs=get_configs_io_block(),
+# key=["M", "N", "K"],
+# )
+@triton.jit
+def _fp8matmul_kernel(
+ A,
+ B,
+ C,
+ noise_ptr, # noise for stochastic
+ M: tl.constexpr,
+ N: tl.constexpr,
+ K: tl.constexpr, #
+ stride_am,
+ stride_ak, #
+ stride_bk,
+ stride_bn, #
+ stride_cm,
+ stride_cn, ##
+ Scale_A,
+ Scale_B,
+ Scale_C,
+ stride_scm,
+ stride_scn,
+ output_quantize: tl.constexpr,
+ QB: tl.constexpr, # default to use 1 * 16 quantization
+ BIAS,
+ fp8_max: tl.constexpr,
+ e_bit: tl.constexpr,
+ m_bit: tl.constexpr,
+ SCALE_MIN_THRES: tl.constexpr,
+ STOCHASTIC: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr,
+):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ grid_m = tl.cdiv(M, BLOCK_M)
+ grid_n = tl.cdiv(N, BLOCK_N)
+ # re-order program ID for better L2 performance
+ width = GROUP_M * grid_n
+ group_id = pid // width
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+ pid_m = group_id * GROUP_M + (pid % group_size)
+ pid_n = (pid % width) // (group_size)
+ # do matrix multiplication
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k in tl.range(0, tl.cdiv(K, BLOCK_K)):
+ k_remaining = K - k * BLOCK_K
+ a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
+ b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
+
+ acc = tl.dot(a, b, acc)
+
+ A += BLOCK_K * stride_ak
+ B += BLOCK_K * stride_bk
+
+ scale_a = tl.load(Scale_A)
+ scale_b = tl.load(Scale_B)
+ scale_ab = scale_a.to(tl.float32) * scale_b.to(tl.float32)
+ # fp8 dequantize
+ acc = acc * scale_ab
+
+ if BIAS:
+ bias = tl.load(BIAS + rbn)
+ acc = acc + bias
+
+ # rematerialize rm and rn to save registers
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
+
+ if output_quantize:
+ acc = tl.reshape(acc, (BLOCK_M, BLOCK_N // QB, QB))
+ abs_acc = tl.abs(acc)
+ acc_max = tl.max(abs_acc, axis=2) + SCALE_MIN_THRES
+ # tl.device_print("acc_max", acc_max)
+ acc_scale = acc_max / fp8_max
+ # tl.device_print("acc_scale", acc_scale)
+ acc_scale = tl.reshape(acc_scale, (BLOCK_M, BLOCK_N // QB, 1))
+ acc = tl.fdiv(acc, acc_scale)
+ acc = tl.reshape(acc, (BLOCK_M, BLOCK_N))
+
+ if STOCHASTIC:
+ noise_block_ptr = noise_ptr + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
+ noise = tl.load(noise_block_ptr, boundary_check=(0, 1))
+ acc = _stochastic_rounding(acc, noise, e_bit, m_bit)
+
+ acc_scale = tl.reshape(acc_scale, (BLOCK_M, BLOCK_N // QB))
+ acc_scale = acc_scale.to(Scale_C.type.element_ty)
+ acc = acc.to(C.dtype.element_ty)
+
+ rsm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rsn = pid_n * BLOCK_N // QB + tl.arange(0, BLOCK_N // QB)
+ Scale_C = Scale_C + (rsm[:, None] * stride_scm + rsn[None, :] * stride_scn)
+
+ tl.store(C, acc, mask=mask)
+ tl.store(Scale_C, acc_scale)
+
+ else:
+ # handles write-back with reduction-splitting
+ acc = acc.to(C.dtype.element_ty)
+ tl.store(C, acc, mask=mask)
+
+
+def fp8matmul(a, b, output_quantize, scale_a, scale_b, QB, bias=None, stochastic=False):
+ # Deal with batched input
+ if len(a.shape) == 3:
+ BS, batched = a.shape[0], True
+ a = a.reshape(-1, a.shape[2])
+ else:
+ batched = False
+
+ # Check constraints.
+ assert a.shape[1] == b.shape[0], "Incompatible dimensions"
+ assert a.is_contiguous(), "Matrix A must be contiguous"
+ M, K = a.shape
+ K, N = b.shape
+ fp8MaxValue = FP8_MAX_VALUE[a.dtype] # E4M3 and E5M2 have different max value
+ e_bit, m_bit = convert_fp8_to_embit[a.dtype]
+
+ # Allocates output.
+ if output_quantize:
+ c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+ # c = torch.empty((M, N), device=a.device, dtype=torch.float32)
+ scale_c = torch.empty((M, N // QB), device=a.device, dtype=torch.bfloat16)
+ else:
+ c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
+        scale_c = torch.empty(
+            (1, 1), device=a.device, dtype=torch.bfloat16
+        )  # placeholder only; scale_c is not used when output_quantize is False
+
+ if stochastic:
+ noise = torch.empty_like(c, dtype=torch.float32).uniform_(-0.5, 0.5)
+ else:
+ noise = None
+
+ # 1D launch kernel where each block gets its own program.
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+ _fp8matmul_kernel[grid](
+ a,
+ b,
+ c,
+ noise, #
+ M,
+ N,
+ K, #
+ a.stride(0),
+ a.stride(1), #
+ b.stride(0),
+ b.stride(1), #
+ c.stride(0),
+ c.stride(1), #
+ scale_a,
+ scale_b,
+ scale_c,
+ scale_c.stride(0),
+ scale_c.stride(1),
+ output_quantize=output_quantize,
+ QB=QB,
+ BIAS=bias,
+ fp8_max=fp8MaxValue,
+ e_bit=e_bit,
+ m_bit=m_bit,
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ STOCHASTIC=stochastic,
+ BLOCK_M=128,
+ BLOCK_N=256,
+ BLOCK_K=128,
+ GROUP_M=8,
+ num_stages=3,
+ num_warps=8,
+ )
+ # Reshape output to batch
+ if batched:
+ c = c.reshape(BS, -1, N)
+ if output_quantize:
+ scale_c = scale_c.reshape(BS, -1, N // QB)
+ return c, scale_c
+ else:
+ if output_quantize:
+ scale_c = scale_c.reshape(M, N // QB)
+ return c, scale_c
+ return c
+
+
+def fp8_linear_forward(x, s, w, s_w, output_quantize, QB, bias=None):
+ assert s.numel() == 1, f"X uses per-tensor quantization in linear forward, but the scale shape is {s.shape}"
+ assert s_w.numel() == 1, f"W uses per-tensor quantization in linear forward, but the scale shape is {s_w.shape}"
+
+ w_t = w.t()
+ return fp8matmul(x, w_t, output_quantize, s, s_w, QB, bias)
+
+
+# def fp8_linear_forward(x, s, w, s_w, output_quantize, QB):
+# print("you are using the wrong linear function. ")
+# w_t = w.t()
+# if output_quantize:
+# return fp8matmul(x, w_t, True, s, s_w, QB)
+# else:
+# y = fp8matmul(x, w_t, False, s, s_w, QB)
+
+# return y
+
+
+def fp8_linear_backward(
+ x_t, s, g, s_g, g_t, w_t, s_w, QB, bias=None, stochastic=False, dgrad_quantize=False
+): # dgrad_quantize=True for backward before flashattention
+ assert s.numel() == 1, f"X uses per-tensor quantization in linear backward, but the scale shape is {s.shape}"
+    assert s_g.numel() == 1, f"G uses per-tensor quantization in linear backward, but the scale shape is {s_g.shape}"
+ assert s_w.numel() == 1, f"W uses per-tensor quantization in linear backward, but the scale shape is {s_w.shape}"
+
+ batched = False
+ if len(g.shape) == 3: # others must be of 2D!
+ batched = True
+ BS = g.shape[0]
+ g = g.reshape(-1, g.shape[-1])
+
+ w_t_t = w_t.t()
+ x_t_t = x_t.t()
+ if dgrad_quantize:
+ y, s_y = fp8matmul(g, w_t_t, True, s_g, s_w, QB, stochastic=stochastic)
+ else:
+ y = fp8matmul(g, w_t_t, False, s_g, s_w, QB)
+
+ w_g = fp8matmul(g_t, x_t_t, False, s_g, s, QB)
+
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ if dgrad_quantize:
+ if s_y.numel() > 1:
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+ if dgrad_quantize:
+ return y, s_y, w_g
+ else:
+ return y, w_g
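+
+
+# Illustrative usage sketch (a comment only; tensor names are assumptions):
+# both the activation and the weight are per-tensor-quantized FP8.
+# `output_quantize` selects between a plain BF16 output (e.g. right before
+# FlashAttention) and a 1 * 16 group-quantized output.
+#
+#   y = fp8_linear_forward(x, s_x, w, s_w, output_quantize=False, QB=16)
+#   dx, dw = fp8_linear_backward(x_t, s_x, g, s_g, g_t, w_t, s_w, QB=16)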
diff --git a/llava/model/coat/activation/real_quantization/mul_bwd.py b/llava/model/coat/activation/real_quantization/mul_bwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..27f8b2839495d6d4f92fa35d537ea75e3e04c0c2
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/mul_bwd.py
@@ -0,0 +1,320 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division_transpose import fp8_division_transpose
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block
+
+"""Element-wise Multiplication Backward"""
+"""Input1 (Gate) uses 1 * 16 group quantization"""
+"""Input2 (Up) uses 1 * 16 group quantization"""
+"""Grad (Down) uses full-precision/BF16"""
+"""Output1 (Gate) uses full-precision/BF16"""
+"""Output2 (Up) uses per-tensor quantization, we can choose whether it should be quantized inside this function"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_mul_backward_kernel(
+ output1_ptr, # output
+ output2_ptr,
+ output2_scale_ptr, # output
+ input1_ptr,
+ input1_scale_ptr, # input
+ input2_ptr,
+ input2_scale_ptr, # input
+ grad_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max,
+ e_bit,
+ m_bit, # shape
+ input1_stride_0,
+ input1_stride_1, # input1 stride
+ s_input1_stride_0,
+ s_input1_stride_1, # scale of input1 stride
+ input2_stride_0,
+ input2_stride_1, # input2 stride
+ s_input2_stride_0,
+ s_input2_stride_1, # scale of input2 stride
+ grad_stride_0,
+ grad_stride_1, # input stride
+ output1_stride_0,
+ output1_stride_1, # output stride
+ output2_stride_0,
+ output2_stride_1, # output stride
+ s_output2_stride_0,
+ s_output2_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ STOCHASTIC: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # --- The first input ---
+ input1_block_ptr = tl.make_block_ptr(
+ base=input1_ptr,
+ shape=(M, N),
+ strides=(input1_stride_0, input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input1_ptr = tl.make_block_ptr(
+ base=input1_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input1_stride_0, s_input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input1 = tl.load(input1_block_ptr)
+ scale_input1 = tl.load(scale_input1_ptr)
+
+ input1 = input1.to(tl.float32)
+ scale_input1 = scale_input1.to(tl.float32)
+
+ # Dequantize and mul calculation
+ scale_input1 = tl.reshape(scale_input1, (BLOCK_M, BLOCK_SN, 1))
+ input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB))
+ input1 = input1 * scale_input1
+
+ # --- The second input ---
+ input2_block_ptr = tl.make_block_ptr(
+ base=input2_ptr,
+ shape=(M, N),
+ strides=(input2_stride_0, input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input2_ptr = tl.make_block_ptr(
+ base=input2_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input2_stride_0, s_input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input2 = tl.load(input2_block_ptr)
+ scale_input2 = tl.load(scale_input2_ptr)
+
+ input2 = input2.to(tl.float32)
+ scale_input2 = scale_input2.to(tl.float32)
+
+ # Dequantize and mul calculation
+ scale_input2 = tl.reshape(scale_input2, (BLOCK_M, BLOCK_SN, 1))
+ input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB))
+ input2 = input2 * scale_input2
+
+ # pointers of gradient
+ grad_block_ptr = tl.make_block_ptr(
+ base=grad_ptr,
+ shape=(M, N),
+ strides=(grad_stride_0, grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ grad = tl.load(grad_block_ptr)
+ grad = grad.to(tl.float32)
+ grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB))
+
+    # Mul backward w.r.t. input1 (gate): d(x1 * x2)/d(x1) = x2, so grad1 = grad * input2
+ grad1 = grad * input2
+ grad1 = tl.reshape(grad1, (BLOCK_M, BLOCK_N))
+ grad1 = grad1.to(output1_ptr.type.element_ty)
+
+ # pointers
+ output1_block_ptr = tl.make_block_ptr(
+ base=output1_ptr,
+ shape=(M, N),
+ strides=(output1_stride_0, output1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ tl.store(output1_block_ptr, grad1, boundary_check=(0, 1))
+
+    # Mul backward w.r.t. input2 (up): d(x1 * x2)/d(x2) = x1, so grad2 = grad * input1
+ grad2 = grad * input1
+    # Quantize the grad 2 - Scale calculation
+ abs_grad2 = tl.abs(grad2)
+ max_val = tl.max(abs_grad2, axis=2) + SCALE_MIN_THRES
+ scale_grad2 = max_val / fp8_max
+ scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN, 1))
+ # Quantize
+ # grad1 = tl.fdiv(grad1, scale_output) # do not quantize the output due to the data flow
+ grad2 = grad2.to(output2_ptr.type.element_ty)
+ scale_grad2 = scale_grad2.to(output2_scale_ptr.type.element_ty)
+ scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN))
+ grad2 = tl.reshape(grad2, (BLOCK_M, BLOCK_N))
+
+ # pointers
+ output2_block_ptr = tl.make_block_ptr(
+ base=output2_ptr,
+ shape=(M, N),
+ strides=(output2_stride_0, output2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output2_ptr = tl.make_block_ptr(
+ base=output2_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output2_stride_0, s_output2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+ tl.store(output2_block_ptr, grad2, boundary_check=(0, 1))
+ tl.store(scale_output2_ptr, scale_grad2, boundary_check=(0, 1))
+
+
+def fp8_mul_backward(
+ x1, s_x1, x2, s_x2, g, QB, fp8type, stochastic=False, output_quantized_transpose=False
+): # Stochastic Rounding is left outside this function
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x1.shape) == 3:
+ assert len(s_x1.shape) == 3
+ batched = True
+ BS = x1.shape[0]
+ x1 = x1.reshape(-1, x1.shape[-1])
+ s_x1 = s_x1.reshape(-1, s_x1.shape[-1])
+ x2 = x2.reshape(-1, x2.shape[-1])
+ s_x2 = s_x2.reshape(-1, s_x2.shape[-1])
+ g = g.reshape(-1, g.shape[-1])
+
+ if stochastic:
+ noise = torch.empty_like(g, dtype=torch.float32).uniform_(-0.5, 0.5)
+ else:
+ noise = None
+
+ # defining the input and output tensor
+ M, N = x1.shape
+ _, SN = s_x1.shape # assume the shape of quantization block size is always 1 * G
+ assert x1.shape == x2.shape
+ assert s_x1.shape == s_x2.shape
+
+ if isinstance(fp8type, str):
+ fp8type = convert_str_to_fp8[fp8type]
+ y1 = torch.empty_like(g, dtype=torch.bfloat16)
+ y2 = torch.empty_like(g, dtype=torch.bfloat16)
+ s_y2 = torch.empty_like(s_x1, dtype=torch.bfloat16)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+ e_bit, m_bit = convert_fp8_to_embit[fp8type]
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_mul_backward_kernel[grid](
+ y1,
+ y2,
+ s_y2,
+ x1,
+ s_x1,
+ x2,
+ s_x2,
+ g,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ e_bit,
+ m_bit,
+ x1.stride(0),
+ x1.stride(1),
+ s_x1.stride(0),
+ s_x1.stride(1),
+ x2.stride(0),
+ x2.stride(1),
+ s_x2.stride(0),
+ s_x2.stride(1),
+ g.stride(0),
+ g.stride(1),
+ y1.stride(0),
+ y1.stride(1),
+ y2.stride(0),
+ y2.stride(1),
+ s_y2.stride(0),
+ s_y2.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ STOCHASTIC=stochastic,
+ )
+
+ if not output_quantized_transpose:
+ # Recover 2D to 3D
+ if batched:
+ y1 = y1.reshape(BS, -1, y1.shape[-1])
+
+ y2 = y2.reshape(BS, -1, y2.shape[-1])
+ s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1])
+
+ return y1, (y2, s_y2)
+ else:
+ # Per-tensor quantization
+ s_y2_max = s_y2.max()
+ qy2, s_y2_max, qy2_t = fp8_division_transpose(y2, QB, fp8type, s_y2_max)
+
+ # Recover 2D to 3D
+ if batched:
+ y1 = y1.reshape(BS, -1, y1.shape[-1])
+
+ y2 = y2.reshape(BS, -1, y2.shape[-1])
+ qy2 = qy2.reshape(BS, -1, qy2.shape[-1])
+ s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1])
+
+ return y1, (qy2, s_y2_max, qy2_t)
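+
+
+# Illustrative usage sketch (a comment only; tensor names are assumptions):
+# `x1`/`s_x1` is the group-quantized gate branch, `x2`/`s_x2` the up branch,
+# and `g` the BF16 gradient of the element-wise product. The first return
+# value is the BF16 gradient for the gate branch; the second holds the
+# (optionally per-tensor quantized and transposed) gradient for the up branch.
+#
+#   g_gate, (g_up, s_g_up) = fp8_mul_backward(
+#       x1, s_x1, x2, s_x2, g, QB=16, fp8type=torch.float8_e4m3fn)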
diff --git a/llava/model/coat/activation/real_quantization/mul_bwd_legacy.py b/llava/model/coat/activation/real_quantization/mul_bwd_legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..803789e83cb27330d6195a515805440af45cc528
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/mul_bwd_legacy.py
@@ -0,0 +1,374 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division import _stochastic_rounding
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, get_configs_io_block
+
+"""Element-wise Multiplication Backward"""
+"""Input1 (Gate) uses 1 * 16 group quantization"""
+"""Input2 (Up) uses 1 * 16 group quantization"""
+"""Grad (Down) uses 1 * 16 group quantization"""
+"""Output1 (Gate) uses 1 * 16 quantization"""
+"""Output2 (Up) uses per-tensor quantization, but should be quantized outside this function""" # Although it is per-tensor quantization, we only apply per-group quantization here, and the reduction should be performed outside this function.
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_mul_backward_legacy_kernel(
+ output1_ptr,
+ output1_scale_ptr, # output
+ output2_ptr,
+ output2_scale_ptr, # output
+ input1_ptr,
+ input1_scale_ptr, # input
+ input2_ptr,
+ input2_scale_ptr, # input
+ grad_ptr,
+ grad_scale_ptr, # input
+ noise_ptr, # noise for stochastic
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max,
+ e_bit,
+ m_bit, # shape
+ input1_stride_0,
+ input1_stride_1, # input1 stride
+ s_input1_stride_0,
+ s_input1_stride_1, # scale of input1 stride
+ input2_stride_0,
+ input2_stride_1, # input2 stride
+ s_input2_stride_0,
+ s_input2_stride_1, # scale of input2 stride
+ grad_stride_0,
+ grad_stride_1, # input stride
+ s_grad_stride_0,
+ s_grad_stride_1, # scale of input stride
+ output1_stride_0,
+ output1_stride_1, # output stride
+ s_output1_stride_0,
+ s_output1_stride_1, # scale of output stride
+ output2_stride_0,
+ output2_stride_1, # output stride
+ s_output2_stride_0,
+ s_output2_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ STOCHASTIC: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # --- The first input ---
+ input1_block_ptr = tl.make_block_ptr(
+ base=input1_ptr,
+ shape=(M, N),
+ strides=(input1_stride_0, input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input1_ptr = tl.make_block_ptr(
+ base=input1_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input1_stride_0, s_input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input1 = tl.load(input1_block_ptr)
+ scale_input1 = tl.load(scale_input1_ptr)
+
+ input1 = input1.to(tl.float32)
+ scale_input1 = scale_input1.to(tl.float32)
+
+ # Dequantize and mul calculation
+ scale_input1 = tl.reshape(scale_input1, (BLOCK_M, BLOCK_SN, 1))
+ input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB))
+ input1 = input1 * scale_input1
+
+ # --- The second input ---
+ input2_block_ptr = tl.make_block_ptr(
+ base=input2_ptr,
+ shape=(M, N),
+ strides=(input2_stride_0, input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input2_ptr = tl.make_block_ptr(
+ base=input2_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input2_stride_0, s_input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input2 = tl.load(input2_block_ptr)
+ scale_input2 = tl.load(scale_input2_ptr)
+
+ input2 = input2.to(tl.float32)
+ scale_input2 = scale_input2.to(tl.float32)
+
+ # Dequantize and mul calculation
+ scale_input2 = tl.reshape(scale_input2, (BLOCK_M, BLOCK_SN, 1))
+ input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB))
+ input2 = input2 * scale_input2
+
+ # pointers of gradient
+ grad_block_ptr = tl.make_block_ptr(
+ base=grad_ptr,
+ shape=(M, N),
+ strides=(grad_stride_0, grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # grad ptr
+ scale_grad_ptr = tl.make_block_ptr(
+ base=grad_scale_ptr,
+ shape=(M, SN),
+ strides=(s_grad_stride_0, s_grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ grad = tl.load(grad_block_ptr)
+ scale_grad = tl.load(scale_grad_ptr)
+
+ grad = grad.to(tl.float32)
+ scale_grad = scale_grad.to(tl.float32)
+
+    # Dequantize the gradient
+ scale_grad = tl.reshape(scale_grad, (BLOCK_M, BLOCK_SN, 1))
+ grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB))
+ grad = grad * scale_grad
+
+    # Mul backward w.r.t. input1 (gate): grad1 = grad * input2
+ grad1 = grad * input2
+ # Quantize the grad 1 - Scale calculation
+ abs_grad1 = tl.abs(grad1)
+ max_val = tl.max(abs_grad1, axis=2) + SCALE_MIN_THRES
+ scale_grad1 = max_val / fp8_max
+ scale_grad1 = tl.reshape(scale_grad1, (BLOCK_M, BLOCK_SN, 1))
+ # Quantize
+ grad1 = tl.fdiv(grad1, scale_grad1) # do not quantize the output due to the data flow
+ scale_grad1 = scale_grad1.to(output1_scale_ptr.type.element_ty)
+ scale_grad1 = tl.reshape(scale_grad1, (BLOCK_M, BLOCK_SN))
+ grad1 = tl.reshape(grad1, (BLOCK_M, BLOCK_N))
+
+ if STOCHASTIC:
+ # noise_block_ptr = tl.make_block_ptr(
+ # base=noise_ptr,
+ # shape=(M, N),
+ # strides=(input1_stride_0, input1_stride_1),
+ # offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ # block_shape=(BLOCK_M, BLOCK_N),
+ # order=(1, 0)
+ # )
+ # noise = tl.load(noise_block_ptr)
+
+ offs_m = pid_dim0 * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_dim1 * BLOCK_N + tl.arange(0, BLOCK_N)
+ noise_offset = offs_m[:, None] * input1_stride_0 + offs_n[None, :] * input1_stride_1
+ noise = tl.rand(0, noise_offset)
+
+ grad1 = _stochastic_rounding(grad1, noise, e_bit, m_bit)
+
+ grad1 = grad1.to(output1_ptr.type.element_ty)
+
+ # pointers
+ output1_block_ptr = tl.make_block_ptr(
+ base=output1_ptr,
+ shape=(M, N),
+ strides=(output1_stride_0, output1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output1_ptr = tl.make_block_ptr(
+ base=output1_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output1_stride_0, s_output1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+ tl.store(output1_block_ptr, grad1, boundary_check=(0, 1))
+ tl.store(scale_output1_ptr, scale_grad1, boundary_check=(0, 1))
+
+    # Mul backward w.r.t. input2 (up): grad2 = grad * input1
+ grad2 = grad * input1
+    # Quantize the grad 2 - Scale calculation
+ abs_grad2 = tl.abs(grad2)
+ max_val = tl.max(abs_grad2, axis=2) + SCALE_MIN_THRES
+ scale_grad2 = max_val / fp8_max
+ scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN, 1))
+ # Quantize
+ # grad1 = tl.fdiv(grad1, scale_output) # do not quantize the output due to the data flow
+ grad2 = grad2.to(output2_ptr.type.element_ty)
+ scale_grad2 = scale_grad2.to(output2_scale_ptr.type.element_ty)
+ scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN))
+ grad2 = tl.reshape(grad2, (BLOCK_M, BLOCK_N))
+
+ # pointers
+ output2_block_ptr = tl.make_block_ptr(
+ base=output2_ptr,
+ shape=(M, N),
+ strides=(output2_stride_0, output2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output2_ptr = tl.make_block_ptr(
+ base=output2_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output2_stride_0, s_output2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+ tl.store(output2_block_ptr, grad2, boundary_check=(0, 1))
+ tl.store(scale_output2_ptr, scale_grad2, boundary_check=(0, 1))
+
+
+def fp8_mul_backward_legacy(
+ x1, s_x1, x2, s_x2, g, s_g, QB, stochastic=False
+): # Stochastic Rounding is left outside this function
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x1.shape) == 3:
+ assert len(s_x1.shape) == 3
+ batched = True
+ BS = x1.shape[0]
+ x1 = x1.reshape(-1, x1.shape[-1])
+ s_x1 = s_x1.reshape(-1, s_x1.shape[-1])
+ x2 = x2.reshape(-1, x2.shape[-1])
+ s_x2 = s_x2.reshape(-1, s_x2.shape[-1])
+ g = g.reshape(-1, g.shape[-1])
+ s_g = s_g.reshape(-1, s_g.shape[-1])
+
+ if stochastic:
+ noise = torch.empty_like(g, dtype=torch.float32).uniform_(-0.5, 0.5)
+ else:
+ noise = None
+
+ # defining the input and output tensor
+ M, N = x1.shape
+ _, SN = s_x1.shape # assume the shape of quantization block size is always 1 * G
+ assert x1.shape == x2.shape
+ assert s_x1.shape == s_x2.shape
+
+ y1 = torch.empty_like(g, dtype=g.dtype)
+ s_y1 = torch.empty_like(s_g, dtype=s_g.dtype)
+ y2 = torch.empty_like(g, dtype=torch.bfloat16)
+ s_y2 = torch.empty_like(s_g, dtype=s_g.dtype)
+ fp8MaxValue = FP8_MAX_VALUE[g.dtype] # E4M3 and E5M2 have different max value
+ e_bit, m_bit = convert_fp8_to_embit[g.dtype]
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_mul_backward_legacy_kernel[grid](
+ y1,
+ s_y1,
+ y2,
+ s_y2,
+ x1,
+ s_x1,
+ x2,
+ s_x2,
+ g,
+ s_g,
+ noise,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ e_bit,
+ m_bit,
+ x1.stride(0),
+ x1.stride(1),
+ s_x1.stride(0),
+ s_x1.stride(1),
+ x2.stride(0),
+ x2.stride(1),
+ s_x2.stride(0),
+ s_x2.stride(1),
+ g.stride(0),
+ g.stride(1),
+ s_g.stride(0),
+ s_g.stride(1),
+ y1.stride(0),
+ y1.stride(1),
+ s_y1.stride(0),
+ s_y1.stride(1),
+ y2.stride(0),
+ y2.stride(1),
+ s_y2.stride(0),
+ s_y2.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ STOCHASTIC=stochastic,
+ )
+
+ # Recover 2D to 3D
+ if batched:
+ y1 = y1.reshape(BS, -1, y1.shape[-1])
+ s_y1 = s_y1.reshape(BS, -1, s_y1.shape[-1])
+
+ y2 = y2.reshape(BS, -1, y2.shape[-1])
+ s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1])
+
+ return y1, s_y1, y2, s_y2
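+
+
+# Illustrative usage sketch (a comment only; tensor names are assumptions):
+# unlike the non-legacy variant, the incoming gradient here is itself 1 * 16
+# group-quantized (`g`, `s_g`), and both returned gradients keep group scales.
+#
+#   g1, s_g1, g2, s_g2 = fp8_mul_backward_legacy(x1, s_x1, x2, s_x2, g, s_g, QB=16)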
diff --git a/llava/model/coat/activation/real_quantization/mul_bwd_silu_fwd.py b/llava/model/coat/activation/real_quantization/mul_bwd_silu_fwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..d23c26f6c86284b8d614ae0f90819bc2e4bc1718
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/mul_bwd_silu_fwd.py
@@ -0,0 +1,337 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division_transpose import fp8_division_transpose
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_fp8_to_embit, convert_str_to_fp8, get_configs_io_block
+
+"""Element-wise Multiplication Backward"""
+"""Input1 (Gate) uses 1 * 16 group quantization"""
+"""Input2 (Up) uses 1 * 16 group quantization"""
+"""Grad (Down) uses full-precision/BF16"""
+"""Output1 (Gate) uses full-precision/BF16"""
+"""Output2 (Up) uses per-tensor quantization, we can choose whether it should be quantized inside this function"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_mul_backward_silu_forward_kernel(
+ output1_ptr, # output
+ output2_ptr,
+ output2_scale_ptr, # output
+ input1_ptr,
+ input1_scale_ptr, # input
+ input2_ptr,
+ input2_scale_ptr, # input
+ grad_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max,
+ e_bit,
+ m_bit, # shape
+ input1_stride_0,
+ input1_stride_1, # input1 stride
+ s_input1_stride_0,
+ s_input1_stride_1, # scale of input1 stride
+ input2_stride_0,
+ input2_stride_1, # input2 stride
+ s_input2_stride_0,
+ s_input2_stride_1, # scale of input2 stride
+ grad_stride_0,
+ grad_stride_1, # input stride
+ output1_stride_0,
+ output1_stride_1, # output stride
+ output2_stride_0,
+ output2_stride_1, # output stride
+ s_output2_stride_0,
+ s_output2_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ STOCHASTIC: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # --- The first input ---
+ input1_block_ptr = tl.make_block_ptr(
+ base=input1_ptr,
+ shape=(M, N),
+ strides=(input1_stride_0, input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input1_ptr = tl.make_block_ptr(
+ base=input1_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input1_stride_0, s_input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input1 = tl.load(input1_block_ptr)
+ scale_input1 = tl.load(scale_input1_ptr)
+
+ input1 = input1.to(tl.float32)
+ scale_input1 = scale_input1.to(tl.float32)
+
+ # Dequantize and mul calculation
+ scale_input1 = tl.reshape(scale_input1, (BLOCK_M, BLOCK_SN, 1))
+ input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB))
+ input1 = input1 * scale_input1
+
+ # --- recompute SiLU ---
+ # Actual Calculation of SiLU
+ sigmoid = 1 / (1.0 + libdevice.exp(-input1))
+ silu_output = sigmoid * input1
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(silu_output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ # Quantize
+ silu_output = tl.fdiv(silu_output, scale_output)
+ silu_output = silu_output.to(input1_ptr.type.element_ty)
+
+ input1 = silu_output * scale_output
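+    # The divide / cast-to-FP8 / multiply round trip above reproduces the
+    # quantization applied to SiLU(gate) in the forward pass, so the recomputed
+    # gate matches the value the multiplication actually used.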
+
+ # --- The second input ---
+ input2_block_ptr = tl.make_block_ptr(
+ base=input2_ptr,
+ shape=(M, N),
+ strides=(input2_stride_0, input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input2_ptr = tl.make_block_ptr(
+ base=input2_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input2_stride_0, s_input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input2 = tl.load(input2_block_ptr)
+ scale_input2 = tl.load(scale_input2_ptr)
+
+ input2 = input2.to(tl.float32)
+ scale_input2 = scale_input2.to(tl.float32)
+
+ # Dequantize and mul calculation
+ scale_input2 = tl.reshape(scale_input2, (BLOCK_M, BLOCK_SN, 1))
+ input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB))
+ input2 = input2 * scale_input2
+
+ # pointers of gradient
+ grad_block_ptr = tl.make_block_ptr(
+ base=grad_ptr,
+ shape=(M, N),
+ strides=(grad_stride_0, grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ grad = tl.load(grad_block_ptr)
+ grad = grad.to(tl.float32)
+ grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB))
+
+    # Mul backward w.r.t. input1 (gate branch): grad1 = grad * input2
+ grad1 = grad * input2
+ grad1 = tl.reshape(grad1, (BLOCK_M, BLOCK_N))
+ grad1 = grad1.to(output1_ptr.type.element_ty)
+
+ # pointers
+ output1_block_ptr = tl.make_block_ptr(
+ base=output1_ptr,
+ shape=(M, N),
+ strides=(output1_stride_0, output1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ tl.store(output1_block_ptr, grad1, boundary_check=(0, 1))
+
+    # Mul backward w.r.t. input2 (up): the gate factor is the recomputed SiLU output, so grad2 = grad * input1
+ grad2 = grad * input1
+    # Quantize the grad 2 - Scale calculation
+ abs_grad2 = tl.abs(grad2)
+ max_val = tl.max(abs_grad2, axis=2) + SCALE_MIN_THRES
+ scale_grad2 = max_val / fp8_max
+ scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN, 1))
+ # Quantize
+ # grad1 = tl.fdiv(grad1, scale_output) # do not quantize the output due to the data flow
+ grad2 = grad2.to(output2_ptr.type.element_ty)
+ scale_grad2 = scale_grad2.to(output2_scale_ptr.type.element_ty)
+ scale_grad2 = tl.reshape(scale_grad2, (BLOCK_M, BLOCK_SN))
+ grad2 = tl.reshape(grad2, (BLOCK_M, BLOCK_N))
+
+ # pointers
+ output2_block_ptr = tl.make_block_ptr(
+ base=output2_ptr,
+ shape=(M, N),
+ strides=(output2_stride_0, output2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output2_ptr = tl.make_block_ptr(
+ base=output2_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output2_stride_0, s_output2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+ tl.store(output2_block_ptr, grad2, boundary_check=(0, 1))
+ tl.store(scale_output2_ptr, scale_grad2, boundary_check=(0, 1))
+
+
+def fp8_mul_backward_silu_forward(
+ x1, s_x1, x2, s_x2, g, QB, fp8type, stochastic=False, output_quantized_transpose=False
+): # Stochastic Rounding is left outside this function
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x1.shape) == 3:
+ assert len(s_x1.shape) == 3
+ batched = True
+ BS = x1.shape[0]
+ x1 = x1.reshape(-1, x1.shape[-1])
+ s_x1 = s_x1.reshape(-1, s_x1.shape[-1])
+ x2 = x2.reshape(-1, x2.shape[-1])
+ s_x2 = s_x2.reshape(-1, s_x2.shape[-1])
+ g = g.reshape(-1, g.shape[-1])
+
+ if stochastic:
+ noise = torch.empty_like(g, dtype=torch.float32).uniform_(-0.5, 0.5)
+ else:
+ noise = None
+
+ # defining the input and output tensor
+ M, N = x1.shape
+ _, SN = s_x1.shape # assume the shape of quantization block size is always 1 * G
+ assert x1.shape == x2.shape
+ assert s_x1.shape == s_x2.shape
+
+ if isinstance(fp8type, str):
+ fp8type = convert_str_to_fp8[fp8type]
+ y1 = torch.empty_like(g, dtype=torch.bfloat16)
+ y2 = torch.empty_like(g, dtype=torch.bfloat16)
+ s_y2 = torch.empty_like(s_x1, dtype=torch.bfloat16)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+ e_bit, m_bit = convert_fp8_to_embit[fp8type]
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_mul_backward_silu_forward_kernel[grid](
+ y1,
+ y2,
+ s_y2,
+ x1,
+ s_x1,
+ x2,
+ s_x2,
+ g,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ e_bit,
+ m_bit,
+ x1.stride(0),
+ x1.stride(1),
+ s_x1.stride(0),
+ s_x1.stride(1),
+ x2.stride(0),
+ x2.stride(1),
+ s_x2.stride(0),
+ s_x2.stride(1),
+ g.stride(0),
+ g.stride(1),
+ y1.stride(0),
+ y1.stride(1),
+ y2.stride(0),
+ y2.stride(1),
+ s_y2.stride(0),
+ s_y2.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ STOCHASTIC=stochastic,
+ )
+
+ if not output_quantized_transpose:
+ # Recover 2D to 3D
+ if batched:
+ y1 = y1.reshape(BS, -1, y1.shape[-1])
+
+ y2 = y2.reshape(BS, -1, y2.shape[-1])
+ s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1])
+
+ return y1, (y2, s_y2)
+ else:
+ # Per-tensor quantization
+ s_y2_max = s_y2.max()
+ qy2, s_y2_max, qy2_t = fp8_division_transpose(y2, QB, fp8type, s_y2_max)
+
+ # Recover 2D to 3D
+ if batched:
+ y1 = y1.reshape(BS, -1, y1.shape[-1])
+
+ y2 = y2.reshape(BS, -1, y2.shape[-1])
+ qy2 = qy2.reshape(BS, -1, qy2.shape[-1])
+ s_y2 = s_y2.reshape(BS, -1, s_y2.shape[-1])
+
+ return y1, (qy2, s_y2_max, qy2_t)
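+
+
+# Illustrative usage sketch (a comment only; tensor names are assumptions):
+# this fused kernel recomputes SiLU(gate) from the quantized pre-activation
+# `x1` instead of reading a stored activation, then runs the multiplication
+# backward exactly like fp8_mul_backward.
+#
+#   g_gate, (g_up, s_g_up) = fp8_mul_backward_silu_forward(
+#       x1, s_x1, x2, s_x2, g, QB=16, fp8type=torch.float8_e4m3fn)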
diff --git a/llava/model/coat/activation/real_quantization/mul_fwd.py b/llava/model/coat/activation/real_quantization/mul_fwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..0037c7220c6f1bb0b7b9b5946a7d2553b44c7426
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/mul_fwd.py
@@ -0,0 +1,260 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division_transpose import fp8_division_transpose
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block
+
+"""Element-wise Multiplication Forward"""
+"""Input1 (Gate) uses 1 * 16 group quantization"""
+"""Input2 (Up) uses 1 * 16 group quantization"""
+"""Output uses per-tensor quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+fp8_max_value = {
+ torch.float8_e4m3fn: 448,
+ torch.float8_e5m2: 57344,
+}
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def fp8_mul_forward_kernel(
+ output_ptr,
+ output_scale_ptr, # output
+ input1_ptr,
+ input1_scale_ptr, # input
+ input2_ptr,
+ input2_scale_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input1_stride_0,
+ input1_stride_1, # input1 stride
+ s_input1_stride_0,
+ s_input1_stride_1, # scale of input1 stride
+ input2_stride_0,
+ input2_stride_1, # input2 stride
+ s_input2_stride_0,
+ s_input2_stride_1, # scale of input2 stride
+ output_stride_0,
+ output_stride_1, # output stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # --- The first input ---
+ input1_block_ptr = tl.make_block_ptr(
+ base=input1_ptr,
+ shape=(M, N),
+ strides=(input1_stride_0, input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input1_ptr = tl.make_block_ptr(
+ base=input1_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input1_stride_0, s_input1_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input1 = tl.load(input1_block_ptr)
+ scale_input1 = tl.load(scale_input1_ptr)
+
+ input1 = input1.to(tl.float32)
+ scale_input1 = scale_input1.to(tl.float32)
+
+ # Dequantize and mul calculation
+ scale_input1 = tl.reshape(scale_input1, (BLOCK_M, BLOCK_SN, 1))
+ input1 = tl.reshape(input1, (BLOCK_M, BLOCK_SN, QB))
+ input1 = input1 * scale_input1
+
+ # --- The second input ---
+ input2_block_ptr = tl.make_block_ptr(
+ base=input2_ptr,
+ shape=(M, N),
+ strides=(input2_stride_0, input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input2_ptr = tl.make_block_ptr(
+ base=input2_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input2_stride_0, s_input2_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input2 = tl.load(input2_block_ptr)
+ scale_input2 = tl.load(scale_input2_ptr)
+
+ input2 = input2.to(tl.float32)
+ scale_input2 = scale_input2.to(tl.float32)
+
+ # Dequantize and mul calculation
+ scale_input2 = tl.reshape(scale_input2, (BLOCK_M, BLOCK_SN, 1))
+ input2 = tl.reshape(input2, (BLOCK_M, BLOCK_SN, QB))
+ input2 = input2 * scale_input2
+
+    # Actual calculation of the element-wise multiplication
+ mul_output = input1 * input2
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(mul_output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ # Quantize
+ # mul_output = tl.fdiv(mul_output, scale_output) # do not quantize the output since it should use per-tensor quantization afterwards
+ mul_output = mul_output.to(output_ptr.type.element_ty)
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+ mul_output = tl.reshape(mul_output, (BLOCK_M, BLOCK_N))
+
+ # debug
+ # mul_output = input
+ # scale_output = scale_input
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, mul_output, boundary_check=(0, 1))
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_mul_forward(x1, s_x1, x2, s_x2, QB, transpose_output_2d=False):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x1.shape) == 3:
+ assert len(s_x1.shape) == 3
+ batched = True
+ BS = x1.shape[0]
+ x1 = x1.reshape(-1, x1.shape[-1])
+ s_x1 = s_x1.reshape(-1, s_x1.shape[-1])
+ x2 = x2.reshape(-1, x2.shape[-1])
+ s_x2 = s_x2.reshape(-1, s_x2.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x1.shape
+ _, SN = s_x1.shape # assume the shape of quantization block size is always 1 * G
+ assert x1.shape == x2.shape
+ assert s_x1.shape == s_x2.shape
+
+ y = torch.empty_like(x1, dtype=torch.bfloat16)
+ s_y = torch.empty_like(s_x1, dtype=s_x1.dtype)
+ fp8MaxValue = fp8_max_value[x1.dtype] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ fp8_mul_forward_kernel[grid](
+ y,
+ s_y,
+ x1,
+ s_x1,
+ x2,
+ s_x2,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x1.stride(0),
+ x1.stride(1),
+ s_x1.stride(0),
+ s_x1.stride(1),
+ x2.stride(0),
+ x2.stride(1),
+ s_x2.stride(0),
+ s_x2.stride(1),
+ y.stride(0),
+ y.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ s_y_max = s_y.max()
+ qy, s_y_max, qy_t = fp8_division_transpose(y, QB, x2.dtype, s_y_max)
+ qy = qy.to(x2.dtype)
+ qy_t = qy_t.to(x2.dtype)
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ qy = qy.reshape(BS, -1, qy.shape[-1])
+ if not transpose_output_2d:
+ qy_t = qy_t.reshape(BS, -1, qy.shape[-1])
+
+ return qy, s_y_max, qy_t
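+
+
+# Illustrative usage sketch (a comment only; tensor names are assumptions):
+# both inputs are 1 * 16 group-quantized FP8; the product is re-quantized
+# per-tensor and also returned in transposed form for the following linear
+# layer's backward pass.
+#
+#   q_mul, s_mul_max, q_mul_t = fp8_mul_forward(x1, s_x1, x2, s_x2, QB=16)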
diff --git a/llava/model/coat/activation/real_quantization/silu_bwd.py b/llava/model/coat/activation/real_quantization/silu_bwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..052bf66e6af6f5f3ce40e08e7b5c8c92e97bfbe3
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/silu_bwd.py
@@ -0,0 +1,255 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from ._division_transpose import fp8_division_transpose
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, convert_str_to_fp8, get_configs_io_block
+
+"""SiLU Activation Backward"""
+"""Input uses 1 * 16 group quantization"""
+"""Grad uses full-precision / BF16"""
+"""Output uses per-tensor quantization, we can choose whether it should be quantized inside this function"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_silu_backward_kernel(
+ output_ptr,
+ output_scale_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ grad_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_input_stride_0,
+ s_input_stride_1, # scale of input stride
+ grad_stride_0,
+ grad_stride_1, # input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input_ptr = tl.make_block_ptr(
+ base=input_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input_stride_0, s_input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ scale_input = tl.load(scale_input_ptr)
+
+ input = input.to(tl.float32)
+ scale_input = scale_input.to(tl.float32)
+
+ # Dequantize and silu calculation
+ scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1))
+ input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+ input = input * scale_input
+
+ # pointers of gradient
+ grad_block_ptr = tl.make_block_ptr(
+ base=grad_ptr,
+ shape=(M, N),
+ strides=(grad_stride_0, grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ grad = tl.load(grad_block_ptr)
+ grad = grad.to(tl.float32)
+ grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB))
+
+ # Actual Calculation of SiLU's backward
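+ # d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))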
+ sigmoid = 1 / (1.0 + libdevice.exp(-input))
+ silu_output = sigmoid + input * sigmoid * (1 - sigmoid)
+ silu_output = silu_output * grad
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(silu_output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ # Quantize
+ # silu_output = tl.fdiv(silu_output, scale_output)
+ silu_output = silu_output.to(output_ptr.type.element_ty)
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+ silu_output = tl.reshape(silu_output, (BLOCK_M, BLOCK_N))
+
+ # debug
+ # silu_output = input
+ # scale_output = scale_input
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, silu_output, boundary_check=(0, 1))
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_silu_backward(
+ x, s_x, g, QB, fp8type, stochastic=False, output_quantized_transpose=False
+): # Stochastic Rounding is left outside this function
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+ g = g.reshape(-1, g.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G
+
+ if isinstance(fp8type, str):
+ fp8type = convert_str_to_fp8[fp8type]
+
+ y = torch.empty_like(g, dtype=torch.bfloat16)
+ s_y = torch.empty_like(s_x, dtype=torch.bfloat16)
+ fp8MaxValue = FP8_MAX_VALUE[fp8type] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_silu_backward_kernel[grid](
+ y,
+ s_y,
+ x,
+ s_x,
+ g,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ s_x.stride(0),
+ s_x.stride(1),
+ g.stride(0),
+ g.stride(1),
+ y.stride(0),
+ y.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ if not output_quantized_transpose:
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+
+ return y, s_y
+ else:
+ # Per-tensor quantization
+ s_y_max = s_y.max()
+ qy, s_y_max, qy_t = fp8_division_transpose(y, QB, fp8type, s_y_max)
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ qy = qy.reshape(BS, -1, qy.shape[-1])
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+
+ return qy, s_y_max, qy_t
+
+ # # Per-tensor quantization
+ # s_y_max = s_y.max()
+ # qy, s_y_max = fp8_division(y, QB, fp8type, s_y_max)
+
+ # # Recover 2D to 3D
+ # if batched:
+ # y = y.reshape(BS, -1, y.shape[-1])
+ # qy = qy.reshape(BS, -1, qy.shape[-1])
+ # s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+
+ # return qy, s_y_max
diff --git a/llava/model/coat/activation/real_quantization/silu_bwd_legacy.py b/llava/model/coat/activation/real_quantization/silu_bwd_legacy.py
new file mode 100644
index 0000000000000000000000000000000000000000..88de423ad973f34718f8d44e1d2e37b3113e1737
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/silu_bwd_legacy.py
@@ -0,0 +1,248 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block
+
+"""SiLU Activation Backward"""
+"""Input uses 1 * 16 group quantization"""
+"""Grad uses 1 * 16 group quantization"""
+"""Output uses per-tensor quantization, but should be quantized outside this function"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_silu_backward_legacy_kernel(
+ output_ptr,
+ output_scale_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ grad_ptr,
+ grad_scale_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_input_stride_0,
+ s_input_stride_1, # scale of input stride
+ grad_stride_0,
+ grad_stride_1, # input stride
+ s_grad_stride_0,
+ s_grad_stride_1, # scale of input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input_ptr = tl.make_block_ptr(
+ base=input_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input_stride_0, s_input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ scale_input = tl.load(scale_input_ptr)
+
+ input = input.to(tl.float32)
+ scale_input = scale_input.to(tl.float32)
+
+ # Dequantize and silu calculation
+ scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1))
+ input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+ input = input * scale_input
+
+ # pointers of gradient
+ grad_block_ptr = tl.make_block_ptr(
+ base=grad_ptr,
+ shape=(M, N),
+ strides=(grad_stride_0, grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # grad ptr
+ scale_grad_ptr = tl.make_block_ptr(
+ base=grad_scale_ptr,
+ shape=(M, SN),
+ strides=(s_grad_stride_0, s_grad_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ grad = tl.load(grad_block_ptr)
+ scale_grad = tl.load(scale_grad_ptr)
+
+ grad = grad.to(tl.float32)
+ scale_grad = scale_grad.to(tl.float32)
+
+ # Dequantize and silu calculation
+ scale_grad = tl.reshape(scale_grad, (BLOCK_M, BLOCK_SN, 1))
+ grad = tl.reshape(grad, (BLOCK_M, BLOCK_SN, QB))
+ grad = grad * scale_grad
+
+ # Actual Calculation of SiLU's backward
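+ # d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))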
+ sigmoid = 1 / (1.0 + libdevice.exp(-input))
+ silu_output = sigmoid + input * sigmoid * (1 - sigmoid)
+ silu_output = silu_output * grad
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(silu_output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ # Quantize
+ # silu_output = tl.fdiv(silu_output, scale_output)
+ silu_output = silu_output.to(output_ptr.type.element_ty)
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+ silu_output = tl.reshape(silu_output, (BLOCK_M, BLOCK_N))
+
+ # debug
+ # silu_output = input
+ # scale_output = scale_input
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, silu_output, boundary_check=(0, 1))
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_silu_backward_legacy(x, s_x, g, s_g, QB, stochastic=False): # Stochastic Rounding is left outside this function
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+ g = g.reshape(-1, g.shape[-1])
+ s_g = s_g.reshape(-1, s_g.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G
+
+ y = torch.empty_like(g, dtype=torch.bfloat16)
+ s_y = torch.empty_like(s_g, dtype=s_g.dtype)
+ fp8MaxValue = FP8_MAX_VALUE[g.dtype] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_silu_backward_legacy_kernel[grid](
+ y,
+ s_y,
+ x,
+ s_x,
+ g,
+ s_g,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ s_x.stride(0),
+ s_x.stride(1),
+ g.stride(0),
+ g.stride(1),
+ s_g.stride(0),
+ s_g.stride(1),
+ y.stride(0),
+ y.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+
+ return y, s_y
diff --git a/llava/model/coat/activation/real_quantization/silu_fwd.py b/llava/model/coat/activation/real_quantization/silu_fwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e6dadc42b2e25ad548ec3e2f30cc146b8ce01d0
--- /dev/null
+++ b/llava/model/coat/activation/real_quantization/silu_fwd.py
@@ -0,0 +1,204 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+# 4 block
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+try:
+ from .common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block
+except ImportError:  # fall back to a plain import when this file is run outside the package
+ from common import FP8_MAX_VALUE, SCALE_MIN_THRES, get_configs_io_block
+
+"""SiLU Activation Forward"""
+"""Input uses 1 * 16 group quantization"""
+"""Output uses 1 * 16 group quantization"""
+"""The input can be 2D or 3D, but the calculation is performed in 2D"""
+
+
+@triton.autotune(
+ configs=[] + get_configs_io_block(),
+ key=[
+ "N",
+ ],
+)
+@triton.heuristics(
+ {
+ "BLOCK_SN": lambda args: args["BLOCK_N"] // args["QB"],
+ }
+)
+@triton.jit
+def _fp8_silu_forward_kernel(
+ output_ptr,
+ output_scale_ptr, # output
+ input_ptr,
+ input_scale_ptr, # input
+ M,
+ N,
+ SN,
+ QB: tl.constexpr,
+ fp8_max, # shape
+ input_stride_0,
+ input_stride_1, # input stride
+ s_input_stride_0,
+ s_input_stride_1, # scale of input stride
+ output_stride_0,
+ output_stride_1, # output stride
+ s_output_stride_0,
+ s_output_stride_1, # scale of output stride
+ SCALE_MIN_THRES: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_SN: tl.constexpr,
+): # CUDA block size
+
+ # Block PID
+ pid = tl.program_id(0)
+ NUM_BLOCK_N = tl.cdiv(N, BLOCK_N)
+ pid_dim0 = pid // NUM_BLOCK_N
+ pid_dim1 = pid % NUM_BLOCK_N
+
+ # pointers
+ input_block_ptr = tl.make_block_ptr(
+ base=input_ptr,
+ shape=(M, N),
+ strides=(input_stride_0, input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # input ptr
+ scale_input_ptr = tl.make_block_ptr(
+ base=input_scale_ptr,
+ shape=(M, SN),
+ strides=(s_input_stride_0, s_input_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ input = tl.load(input_block_ptr)
+ scale_input = tl.load(scale_input_ptr)
+
+ input = input.to(tl.float32)
+ scale_input = scale_input.to(tl.float32)
+
+ # Dequantize and silu calculation
+ scale_input = tl.reshape(scale_input, (BLOCK_M, BLOCK_SN, 1))
+ input = tl.reshape(input, (BLOCK_M, BLOCK_SN, QB))
+ input = input * scale_input
+
+ # Actual Calculation of SiLU
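+ # SiLU(x) = x * sigmoid(x)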
+ sigmoid = 1 / (1.0 + libdevice.exp(-input))
+ silu_output = sigmoid * input
+
+ # Quantize Scale calculation
+ abs_output = tl.abs(silu_output)
+ max_val = tl.max(abs_output, axis=2) + SCALE_MIN_THRES
+ scale_output = max_val / fp8_max
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN, 1))
+
+ # Quantize
+ silu_output = tl.fdiv(silu_output, scale_output)
+ silu_output = silu_output.to(output_ptr.type.element_ty)
+
+ scale_output = scale_output.to(output_scale_ptr.type.element_ty)
+ scale_output = tl.reshape(scale_output, (BLOCK_M, BLOCK_SN))
+ silu_output = tl.reshape(silu_output, (BLOCK_M, BLOCK_N))
+
+ # debug
+ # silu_output = input
+ # scale_output = scale_input
+
+ # pointers
+ output_block_ptr = tl.make_block_ptr(
+ base=output_ptr,
+ shape=(M, N),
+ strides=(output_stride_0, output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+ scale_output_ptr = tl.make_block_ptr(
+ base=output_scale_ptr,
+ shape=(M, SN),
+ strides=(s_output_stride_0, s_output_stride_1),
+ offsets=(pid_dim0 * BLOCK_M, pid_dim1 * BLOCK_SN),
+ block_shape=(BLOCK_M, BLOCK_SN),
+ order=(1, 0),
+ )
+
+ tl.store(output_block_ptr, silu_output, boundary_check=(0, 1))
+ tl.store(scale_output_ptr, scale_output, boundary_check=(0, 1))
+
+
+def fp8_silu_forward(x, s_x, QB):
+ # Change batched 3D input to 2D
+ batched = False
+ if len(x.shape) == 3:
+ assert len(s_x.shape) == 3
+ batched = True
+ BS = x.shape[0]
+ x = x.reshape(-1, x.shape[-1])
+ s_x = s_x.reshape(-1, s_x.shape[-1])
+
+ # defining the input and output tensor
+ M, N = x.shape
+ _, SN = s_x.shape # assume the shape of quantization block size is always 1 * G
+
+ y = torch.empty_like(x, dtype=x.dtype)
+ s_y = torch.empty_like(s_x, dtype=s_x.dtype)
+ fp8MaxValue = FP8_MAX_VALUE[x.dtype] # E4M3 and E5M2 have different max value
+
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)
+
+ _fp8_silu_forward_kernel[grid](
+ y,
+ s_y,
+ x,
+ s_x,
+ M,
+ N,
+ SN,
+ QB,
+ fp8MaxValue,
+ x.stride(0),
+ x.stride(1),
+ s_x.stride(0),
+ s_x.stride(1),
+ y.stride(0),
+ y.stride(1),
+ s_y.stride(0),
+ s_y.stride(1),
+ SCALE_MIN_THRES=SCALE_MIN_THRES,
+ )
+
+ # Recover 2D to 3D
+ if batched:
+ y = y.reshape(BS, -1, y.shape[-1])
+ s_y = s_y.reshape(BS, -1, s_y.shape[-1])
+
+ return y, s_y
diff --git a/llava/model/coat/activation/utils.py b/llava/model/coat/activation/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b14b79e29cac93f6049c6fa81ff537bd2a5cb8fd
--- /dev/null
+++ b/llava/model/coat/activation/utils.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from collections import defaultdict
+
+import torch
+
+
+def quant_get_local_rank() -> int:
+ return int(os.environ.get("LOCAL_RANK") or 0)
+
+
+record_memory_allocated = defaultdict(list)
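+
+# Usage notes (illustrative; neither name is used within this module itself):
+# quant_get_local_rank() is handy for rank-0-only logging,
+# e.g. `if quant_get_local_rank() == 0: print(...)`.
+# record_memory_allocated is a module-level defaultdict, presumably meant for
+# collecting memory statistics keyed by call site elsewhere in the codebase.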
diff --git a/llava/model/coat/fp8_trainer.py b/llava/model/coat/fp8_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b9d68029db9fc7fbbd6af01b34d2cf342b4294f
--- /dev/null
+++ b/llava/model/coat/fp8_trainer.py
@@ -0,0 +1,626 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# In this file, the only addition is the FP8-related logic inside the inner training loop (originally lines 411 to 415); everything else is unchanged from the original Hugging Face Trainer class.
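+#
+# That addition simply flags the first micro-batch of every gradient-accumulation window:
+#
+# FP8Manager.is_first_microbatch = (total_batched_samples % args.gradient_accumulation_steps == 0)
+#
+# so downstream FP8 code can distinguish the first micro-batch of each optimizer step.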
+
+import math
+import os
+import shutil
+import sys
+import time
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from packaging import version
+from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
+from transformers import Trainer
+from transformers.trainer import TRAINER_STATE_NAME, XLA_FSDPV2_MIN_VERSION, _is_peft_model  # internal names referenced below; this loop was adapted from transformers.trainer
+from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
+from transformers.integrations import hp_params
+from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
+from transformers.integrations.tpu import tpu_spmd_dataloader
+from transformers.trainer_callback import ExportableState, TrainerState
+from transformers.trainer_pt_utils import get_model_param_count
+from transformers.trainer_utils import HPSearchBackend, TrainOutput, has_length, speed_metrics
+from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
+from transformers.utils import (
+ is_accelerate_available,
+ is_apex_available,
+ is_bitsandbytes_available,
+ is_datasets_available,
+ is_galore_torch_available,
+ is_grokadamw_available,
+ is_in_notebook,
+ is_ipex_available,
+ is_liger_kernel_available,
+ is_lomo_available,
+ is_peft_available,
+ is_safetensors_available,
+ is_sagemaker_dp_enabled,
+ is_sagemaker_mp_enabled,
+ is_schedulefree_available,
+ is_torch_compile_available,
+ is_torch_mlu_available,
+ is_torch_mps_available,
+ is_torch_musa_available,
+ is_torch_neuroncore_available,
+ is_torch_npu_available,
+ is_torch_xla_available,
+ is_torch_xpu_available,
+ is_torchao_available,
+ logging,
+)
+
+if is_apex_available():
+ from apex import amp
+
+if is_datasets_available():
+ import datasets
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+ import torch_xla.debug.metrics as met
+ from torch_xla import __version__ as XLA_VERSION
+
+ IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
+ if IS_XLA_FSDPV2_POST_2_2:
+ import torch_xla.distributed.spmd as xs
+ import torch_xla.runtime as xr
+else:
+ IS_XLA_FSDPV2_POST_2_2 = False
+
+
+if is_sagemaker_mp_enabled():
+ import smdistributed.modelparallel.torch as smp
+ from smdistributed.modelparallel import __version__ as SMP_VERSION
+
+ IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
+
+ from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
+else:
+ IS_SAGEMAKER_MP_POST_1_10 = False
+
+
+if is_safetensors_available():
+ import safetensors.torch
+
+if is_peft_available():
+ from peft import PeftModel
+
+
+if is_accelerate_available():
+ from accelerate import Accelerator
+ from accelerate import __version__ as accelerate_version
+ from accelerate import skip_first_batches
+ from accelerate.utils import (
+ DistributedDataParallelKwargs,
+ DistributedType,
+ GradientAccumulationPlugin,
+ load_fsdp_model,
+ load_fsdp_optimizer,
+ save_fsdp_model,
+ save_fsdp_optimizer,
+ )
+
+ DATA_SAMPLERS = [RandomSampler]
+ if version.parse(accelerate_version) > version.parse("0.23.0"):
+ from accelerate.data_loader import SeedableRandomSampler
+
+ DATA_SAMPLERS += [SeedableRandomSampler]
+
+ if is_deepspeed_available():
+ from accelerate.utils import DeepSpeedSchedulerWrapper
+
+if is_accelerate_available("0.28.0"):
+ from accelerate.utils import DataLoaderConfiguration
+
+from .activation.models._fp8manager import FP8Manager
+
+logger = logging.get_logger(__name__)
+
+
+class CoatFP8Trainer(Trainer):
+ def _inner_training_loop(
+ self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
+ ):
+ self.accelerator.free_memory()
+ self._train_batch_size = batch_size
+ if self.args.auto_find_batch_size:
+ if self.state.train_batch_size != self._train_batch_size:
+ from accelerate.utils import release_memory
+
+ (self.model_wrapped,) = release_memory(self.model_wrapped)
+ self.model_wrapped = self.model
+
+ # Check for DeepSpeed *after* the initial pass and modify the config
+ if self.is_deepspeed_enabled:
+ # Temporarily unset `self.args.train_batch_size`
+ original_bs = self.args.per_device_train_batch_size
+ self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu)
+ self.propagate_args_to_deepspeed(True)
+ self.args.per_device_train_batch_size = original_bs
+ self.state.train_batch_size = self._train_batch_size
+ logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
+ # Data loader and number of training steps
+ train_dataloader = self.get_train_dataloader()
+ if self.is_fsdp_xla_v2_enabled:
+ train_dataloader = tpu_spmd_dataloader(train_dataloader)
+
+ # Setting up training control variables:
+ # number of training epochs: num_train_epochs
+ # number of training steps per epoch: num_update_steps_per_epoch
+ # total number of training steps to execute: max_steps
+ total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
+
+ len_dataloader = None
+ num_train_tokens = None
+ if has_length(train_dataloader):
+ len_dataloader = len(train_dataloader)
+ num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
+ num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
+ num_examples = self.num_examples(train_dataloader)
+ if args.max_steps > 0:
+ max_steps = args.max_steps
+ num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
+ args.max_steps % num_update_steps_per_epoch > 0
+ )
+ # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
+ # the best we can do.
+ num_train_samples = args.max_steps * total_train_batch_size
+ if args.include_tokens_per_second:
+ num_train_tokens = (
+ self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
+ )
+ else:
+ max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
+ num_train_epochs = math.ceil(args.num_train_epochs)
+ num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
+ if args.include_tokens_per_second:
+ num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs
+ elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size
+ max_steps = args.max_steps
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
+ num_train_epochs = sys.maxsize
+ num_update_steps_per_epoch = max_steps
+ num_examples = total_train_batch_size * args.max_steps
+ num_train_samples = args.max_steps * total_train_batch_size
+ if args.include_tokens_per_second:
+ num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
+ else:
+ raise ValueError(
+ "args.max_steps must be set to a positive value if dataloader does not have a length, was"
+ f" {args.max_steps}"
+ )
+
+ if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
+ if self.args.n_gpu > 1:
+ # nn.DataParallel(model) replicates the model, creating new variables and module
+ # references registered here no longer work on other gpus, breaking the module
+ raise ValueError(
+ "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
+ " (torchrun or torch.distributed.launch (deprecated))."
+ )
+ else:
+ debug_overflow = DebugUnderflowOverflow(self.model) # noqa
+
+ delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
+
+ # We need to reset the scheduler, as its parameters may be different on subsequent calls
+ if self._created_lr_scheduler:
+ self.lr_scheduler = None
+ self._created_lr_scheduler = False
+
+ if self.is_deepspeed_enabled:
+ self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)
+
+ if not delay_optimizer_creation:
+ self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+ self.state = TrainerState(
+ stateful_callbacks=[
+ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+ ]
+ )
+ self.state.is_hyper_param_search = trial is not None
+ self.state.train_batch_size = self._train_batch_size
+
+ # Compute absolute values for logging, eval, and save if given as ratio
+ if args.logging_steps is not None:
+ if args.logging_steps < 1:
+ self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
+ else:
+ self.state.logging_steps = args.logging_steps
+ if args.eval_steps is not None:
+ if args.eval_steps < 1:
+ self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
+ else:
+ self.state.eval_steps = args.eval_steps
+ if args.save_steps is not None:
+ if args.save_steps < 1:
+ self.state.save_steps = math.ceil(max_steps * args.save_steps)
+ else:
+ self.state.save_steps = args.save_steps
+
+ # Activate gradient checkpointing if needed
+ if args.gradient_checkpointing:
+ self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)
+
+ model = self._wrap_model(self.model_wrapped)
+
+ # as the model is wrapped, don't use `accelerator.prepare`
+ # this is for unhandled cases such as
+ # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
+ use_accelerator_prepare = True if model is self.model else False
+
+ if delay_optimizer_creation:
+ if use_accelerator_prepare:
+ self._fsdp_qlora_plugin_updates()
+ self.model = self.accelerator.prepare(self.model)
+ self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+ # prepare using `accelerator` prepare
+ if use_accelerator_prepare:
+ self.model.train()
+ if hasattr(self.lr_scheduler, "step"):
+ if self.use_apex:
+ model = self.accelerator.prepare(self.model)
+ else:
+ model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+ else:
+ # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
+ model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
+ self.model, self.optimizer, self.lr_scheduler
+ )
+ elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+ # In this case we are in DDP + LOMO, which should be supported
+ self.optimizer = self.accelerator.prepare(self.optimizer)
+
+ if self.is_fsdp_enabled:
+ self.model = self.model_wrapped = model
+
+ # for the rest of this function `model` is the outside model, whether it was wrapped or not
+ if model is not self.model:
+ self.model_wrapped = model
+
+ # backward compatibility
+ if self.is_deepspeed_enabled:
+ self.deepspeed = self.model_wrapped
+
+ # ckpt loading
+ if resume_from_checkpoint is not None:
+ if self.is_deepspeed_enabled:
+ deepspeed_load_checkpoint(
+ self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
+ )
+ elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
+ self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
+
+ # Check if saved optimizer or scheduler states exist
+ self._load_optimizer_and_scheduler(resume_from_checkpoint)
+
+ # important: at this point:
+ # self.model is the Transformers Model
+ # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
+ # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.
+
+ # Train!
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {num_examples:,}")
+ logger.info(f" Num Epochs = {num_train_epochs:,}")
+ logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
+ if self.args.per_device_train_batch_size != self._train_batch_size:
+ logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {max_steps:,}")
+ logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
+
+ self.state.epoch = 0
+ start_time = time.time()
+ epochs_trained = 0
+ steps_trained_in_current_epoch = 0
+ steps_trained_progress_bar = None
+
+ # Check if continuing training from a checkpoint
+ if resume_from_checkpoint is not None and os.path.isfile(
+ os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
+ ):
+ self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
+ self.compare_trainer_and_checkpoint_args(self.args, self.state)
+ self._load_callback_state()
+ epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
+ if not args.ignore_data_skip:
+ steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+ steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+ else:
+ steps_trained_in_current_epoch = 0
+
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
+ logger.info(f" Continuing training from epoch {epochs_trained}")
+ logger.info(f" Continuing training from global step {self.state.global_step}")
+ if not args.ignore_data_skip:
+ logger.info(
+ f" Will skip the first {epochs_trained} epochs then the first"
+ f" {steps_trained_in_current_epoch} batches in the first epoch."
+ )
+
+ # Update the references
+ self.callback_handler.model = self.model
+ self.callback_handler.optimizer = self.optimizer
+ self.callback_handler.lr_scheduler = self.lr_scheduler
+ self.callback_handler.train_dataloader = train_dataloader
+ if self.hp_name is not None and self._trial is not None:
+ # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
+ # parameter to Train when using DDP.
+ self.state.trial_name = self.hp_name(self._trial)
+ if trial is not None:
+ assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
+ self.state.trial_params = hp_params(assignments)
+ else:
+ self.state.trial_params = None
+ # This should be the same if the state has been saved but in case the training arguments changed, it's safer
+ # to set this after the load.
+ self.state.max_steps = max_steps
+ self.state.num_train_epochs = num_train_epochs
+ self.state.is_local_process_zero = self.is_local_process_zero()
+ self.state.is_world_process_zero = self.is_world_process_zero()
+
+ # tr_loss is a tensor to avoid synchronization of TPUs through .item()
+ tr_loss = torch.tensor(0.0).to(args.device)
+ # _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
+ self._total_loss_scalar = 0.0
+ self._globalstep_last_logged = self.state.global_step
+ model.zero_grad()
+ grad_norm: Optional[float] = None
+ self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
+
+ if args.eval_on_start:
+ self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
+
+ total_batched_samples = 0
+ for epoch in range(epochs_trained, num_train_epochs):
+ epoch_iterator = train_dataloader
+ if hasattr(epoch_iterator, "set_epoch"):
+ epoch_iterator.set_epoch(epoch)
+
+ # Reset the past mems state at the beginning of each epoch if necessary.
+ if args.past_index >= 0:
+ self._past = None
+
+ steps_in_epoch = (
+ len(epoch_iterator) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps
+ )
+ self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
+
+ if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
+ self._load_rng_state(resume_from_checkpoint)
+
+ rng_to_sync = False
+ steps_skipped = 0
+ if steps_trained_in_current_epoch > 0:
+ epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
+ steps_skipped = steps_trained_in_current_epoch
+ steps_trained_in_current_epoch = 0
+ rng_to_sync = True
+
+ step = -1
+ for step, inputs in enumerate(epoch_iterator):
+ # NOTE (FP8): flag the first micro-batch of each gradient-accumulation window
+ # so FP8 modules know a new optimizer step is starting.
+ FP8Manager.is_first_microbatch = total_batched_samples % args.gradient_accumulation_steps == 0
+
+ total_batched_samples += 1
+
+ if self.args.include_num_input_tokens_seen:
+ main_input_name = getattr(self.model, "main_input_name", "input_ids")
+ if main_input_name not in inputs:
+ logger.warning(
+ "Tried to track the number of tokens seen, however the current model is "
+ "not configured properly to know what item is the input. To fix this, add "
+ "a `main_input_name` attribute to the model class you are using."
+ )
+ else:
+ self.state.num_input_tokens_seen += (
+ torch.sum(
+ self.accelerator.gather(
+ torch.tensor(
+ inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64
+ )
+ )
+ )
+ .cpu()
+ .item()
+ )
+ if rng_to_sync:
+ self._load_rng_state(resume_from_checkpoint)
+ rng_to_sync = False
+
+ # Skip past any already trained steps if resuming training
+ if steps_trained_in_current_epoch > 0:
+ steps_trained_in_current_epoch -= 1
+ if steps_trained_progress_bar is not None:
+ steps_trained_progress_bar.update(1)
+ if steps_trained_in_current_epoch == 0:
+ self._load_rng_state(resume_from_checkpoint)
+ continue
+ elif steps_trained_progress_bar is not None:
+ steps_trained_progress_bar.close()
+ steps_trained_progress_bar = None
+
+ if step % args.gradient_accumulation_steps == 0:
+ self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
+
+ with self.accelerator.accumulate(model):
+ tr_loss_step = self.training_step(model, inputs)
+
+ if (
+ args.logging_nan_inf_filter
+ and not is_torch_xla_available()
+ and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
+ ):
+ # if loss is nan or inf simply add the average of previous logged losses
+ tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
+ else:
+ if tr_loss.device != tr_loss_step.device:
+ raise ValueError(
+ f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
+ )
+ tr_loss += tr_loss_step
+
+ self.current_flos += float(self.floating_point_ops(inputs))
+
+ is_last_step_and_steps_less_than_grad_acc = (
+ steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
+ )
+
+ if (
+ total_batched_samples % args.gradient_accumulation_steps == 0
+ or
+ # last step in epoch but step is always smaller than gradient_accumulation_steps
+ is_last_step_and_steps_less_than_grad_acc
+ ):
+ # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
+ # in accelerate. So, explicitly enable sync gradients to True in that case.
+ if is_last_step_and_steps_less_than_grad_acc:
+ self.accelerator.gradient_state._set_sync_gradients(True)
+
+ # Gradient clipping
+ if args.max_grad_norm is not None and args.max_grad_norm > 0:
+ # deepspeed does its own clipping
+
+ if is_sagemaker_mp_enabled() and args.fp16:
+ _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
+ elif self.use_apex:
+ # Revert to normal clipping otherwise, handling Apex or full precision
+ _grad_norm = nn.utils.clip_grad_norm_(
+ amp.master_params(self.optimizer),
+ args.max_grad_norm,
+ )
+ else:
+ _grad_norm = self.accelerator.clip_grad_norm_(
+ model.parameters(),
+ args.max_grad_norm,
+ )
+
+ if is_accelerate_available() and self.accelerator.distributed_type == DistributedType.DEEPSPEED:
+ grad_norm = model.get_global_grad_norm()
+ # In some cases the grad norm may not return a float
+ if hasattr(grad_norm, "item"):
+ grad_norm = grad_norm.item()
+ else:
+ grad_norm = _grad_norm
+
+ self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
+
+ self.optimizer.step()
+
+ self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
+
+ optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
+ if optimizer_was_run:
+ # Delay optimizer scheduling until metrics are generated
+ if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+ self.lr_scheduler.step()
+
+ model.zero_grad()
+ self.state.global_step += 1
+ self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
+ self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+
+ self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
+ else:
+ self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
+
+ if self.control.should_epoch_stop or self.control.should_training_stop:
+ # PyTorch/XLA relies on the data loader to insert the mark_step for
+ # each step. Since we are breaking the loop early, we need to manually
+ # insert the mark_step here.
+ if is_torch_xla_available():
+ xm.mark_step()
+ break
+ if step < 0:
+ logger.warning(
+ "There seems not to be a single sample in your epoch_iterator, stopping training at step"
+ f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
+ f" num_steps ({max_steps}) higher than the number of available samples."
+ )
+ self.control.should_training_stop = True
+
+ self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
+ self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
+
+ if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
+ if is_torch_xla_available():
+ # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+ xm.master_print(met.metrics_report())
+ else:
+ logger.warning(
+ "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
+ "configured. Check your training configuration if this is unexpected."
+ )
+ if self.control.should_training_stop:
+ break
+
+ if args.past_index and hasattr(self, "_past"):
+ # Clean the state at the end of training
+ delattr(self, "_past")
+
+ logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
+ if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
+ # Wait for everyone to get here so we are sure the model has been saved by process 0.
+ if is_torch_xla_available():
+ xm.rendezvous("load_best_model_at_end")
+ elif args.parallel_mode == ParallelMode.DISTRIBUTED:
+ dist.barrier()
+ elif is_sagemaker_mp_enabled():
+ smp.barrier()
+
+ self._load_best_model()
+
+ # add remaining tr_loss
+ self._total_loss_scalar += tr_loss.item()
+ effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError
+ train_loss = self._total_loss_scalar / effective_global_step
+
+ metrics = speed_metrics(
+ "train",
+ start_time,
+ num_samples=num_train_samples,
+ num_steps=self.state.max_steps,
+ num_tokens=num_train_tokens,
+ )
+ self.store_flos()
+ metrics["total_flos"] = self.state.total_flos
+ metrics["train_loss"] = train_loss
+
+ self.is_in_train = False
+
+ self._memory_tracker.stop_and_update_metrics(metrics)
+
+ self.log(metrics)
+
+ run_dir = self._get_output_dir(trial)
+ checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
+
+ # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
+ if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
+ for checkpoint in checkpoints_sorted:
+ if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+ shutil.rmtree(checkpoint, ignore_errors=True)
+
+ self.control = self.callback_handler.on_train_end(args, self.state, self.control)
+
+ # Wait for the checkpoint to be uploaded.
+ self._finish_current_push()
+
+ # After training we make sure to retrieve back the original forward pass method
+ # for the embedding layer by removing the forward post hook.
+ if self.neftune_noise_alpha is not None:
+ self._deactivate_neftune(self.model)
+
+ return TrainOutput(self.state.global_step, train_loss, metrics)
diff --git a/llava/model/coat/optimizer/fp8_adamw.py b/llava/model/coat/optimizer/fp8_adamw.py
new file mode 100644
index 0000000000000000000000000000000000000000..7776ef1539e07773285cf0d220811d6e73070542
--- /dev/null
+++ b/llava/model/coat/optimizer/fp8_adamw.py
@@ -0,0 +1,515 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from collections import OrderedDict, defaultdict
+from copy import deepcopy
+from itertools import chain
+from math import sqrt  # used by _dispatch_sqrt below
+from typing import Any, DefaultDict, Dict, Hashable, Iterable, List, Optional, Tuple, Union
+
+import qoptim_cuda
+import torch
+from torch import Tensor
+from torch.optim.optimizer import Optimizer
+from typing_extensions import ParamSpec, Self, TypeAlias
+
+StateDict: TypeAlias = Dict[str, Any]
+
+convert_str_to_fp8 = {"E4M3": torch.float8_e4m3fn, "E5M2": torch.float8_e5m2}
+
+
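+# A minimal construction sketch (illustrative; the exact `qargs` object is defined
+# elsewhere in the repo, and the fields listed here are simply those accessed in this file):
+#
+# qargs needs: first_order_bit / second_order_bit ("E4M3" or "E5M2"),
+# first_order_expansion / second_order_expansion, qgroup_size, and expand_min.
+#
+# optimizer = CoatAdamW(qargs, model.parameters(), lr=1e-3, weight_decay=1e-2)
+# optimizer.step()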
+class CoatAdamW(Optimizer):
+ def __init__(
+ self,
+ qargs,
+ params,
+ lr: float = 1e-3,
+ betas: Tuple[float, float] = (0.9, 0.999),
+ eps: float = 1e-8,
+ weight_decay: float = 1e-2,
+ amsgrad: bool = False,
+ *,
+ fused: Optional[bool] = None,
+ ):
+ self.qargs = qargs
+ assert self.qargs.first_order_expansion == self.qargs.second_order_expansion
+ if not 0.0 <= lr:
+ raise ValueError(f"Invalid learning rate: {lr}")
+ if not 0.0 <= eps:
+ raise ValueError(f"Invalid epsilon value: {eps}")
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+ if not 0.0 <= weight_decay:
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+ defaults = dict(
+ lr=lr,
+ betas=betas,
+ eps=eps,
+ weight_decay=weight_decay,
+ amsgrad=amsgrad,
+ fused=fused,
+ )
+ super().__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super().__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault("amsgrad", False)
+ fused = group.setdefault("fused", None)
+ for p in group["params"]:
+ p_state = self.state.get(p, [])
+ if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
+ step_val = float(p_state["step"])
+ p_state["step"] = torch.tensor(step_val, dtype=torch.float32)
+
+ def _init_group(
+ self,
+ group,
+ params_with_grad,
+ grads,
+ amsgrad,
+ use_expansion,
+ exp_avgs,
+ scale_exp_avgs,
+ expand_exp_avgs,
+ sqrt_minmax_exp_avgs,
+ exp_avg_sqs,
+ scale_exp_avg_sqs,
+ expand_exp_avg_sqs,
+ sqrt_minmax_exp_avg_sqs,
+ max_exp_avg_sqs,
+ state_steps,
+ ):
+ for p in group["params"]:
+ if p.grad is None:
+ continue
+ params_with_grad.append(p)
+ if p.grad.is_sparse:
+ raise RuntimeError("AdamW does not support sparse gradients")
+ grads.append(p.grad)
+
+ state = self.state[p]
+
+ # print(f'Param shape: {p.shape}', file=open('debug.txt', 'a'))
+ # print(f'Param shape: {p.shape}, {p.device}')
+
+ # State initialization
+ if len(state) == 0:
+ # Keep `step` as a CPU tensor: kernel launches are costly on CUDA and XLA.
+ state["step"] = torch.tensor(0.0)
+
+ # Should be torch.float8_e4m3fn
+ first_order_dtype = convert_str_to_fp8[self.qargs.first_order_bit]
+ second_order_dtype = convert_str_to_fp8[self.qargs.second_order_bit]
+ scale_shape = (p.numel() + self.qargs.qgroup_size - 1) // self.qargs.qgroup_size
+
+ # Exponential moving average of gradient values
+ state["exp_avg"] = torch.zeros_like(p, dtype=first_order_dtype, memory_format=torch.preserve_format)
+ state["scale_exp_avg"] = torch.zeros(scale_shape, device=p.device, dtype=p.dtype)
+ if use_expansion:
+ state["expand_exp_avg"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype)
+ state["sqrt_minmax_exp_avg"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype)
+ # Exponential moving average of squared gradient values
+ state["exp_avg_sq"] = torch.zeros_like(p, dtype=second_order_dtype, memory_format=torch.preserve_format)
+ state["scale_exp_avg_sq"] = torch.zeros(scale_shape, device=p.device, dtype=p.dtype)
+ if use_expansion:
+ state["expand_exp_avg_sq"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype)
+ state["sqrt_minmax_exp_avg_sq"] = torch.ones(scale_shape, device=p.device, dtype=p.dtype)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state["max_exp_avg_sq"] = torch.zeros(p, memory_format=torch.preserve_format)
+
+ exp_avgs.append(state["exp_avg"])
+ scale_exp_avgs.append(state["scale_exp_avg"])
+ if use_expansion:
+ expand_exp_avgs.append(state["expand_exp_avg"])
+ sqrt_minmax_exp_avgs.append(state["sqrt_minmax_exp_avg"])
+ exp_avg_sqs.append(state["exp_avg_sq"])
+ scale_exp_avg_sqs.append(state["scale_exp_avg_sq"])
+ if use_expansion:
+ expand_exp_avg_sqs.append(state["expand_exp_avg_sq"])
+ sqrt_minmax_exp_avg_sqs.append(state["sqrt_minmax_exp_avg_sq"])
+
+ if group["amsgrad"]:
+ max_exp_avg_sqs.append(state["max_exp_avg_sq"])
+
+ state_steps.append(state["step"])
+
+ @torch._disable_dynamo
+ def load_state_dict(self, state_dict: StateDict) -> None:
+ r"""Loads the optimizer state.
+
+ Args:
+ state_dict (dict): optimizer state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ # shallow copy, to be consistent with module API
+ state_dict = state_dict.copy()
+
+ for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
+ hook_result = pre_hook(self, state_dict)
+ if hook_result is not None:
+ state_dict = hook_result
+
+ # Validate the state_dict
+ groups = self.param_groups
+
+ # Deepcopy as we write into saved_groups later to update state
+ saved_groups = deepcopy(state_dict["param_groups"])
+
+ if len(groups) != len(saved_groups):
+ raise ValueError("loaded state dict has a different number of " "parameter groups")
+ param_lens = (len(g["params"]) for g in groups)
+ saved_lens = (len(g["params"]) for g in saved_groups)
+ if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+ raise ValueError(
+ "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group"
+ )
+
+ # Update the state
+ id_map = dict(
+ zip(
+ chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups)
+ )
+ )
+
+ def _cast(param, value, param_id=None, param_groups=None, key=None):
+ r"""Make a deep copy of value, casting all tensors to device of param."""
+ if isinstance(value, torch.Tensor):
+ return CoatAdamW._process_value_according_to_param_policy(param, value, param_id, param_groups, key)
+ elif isinstance(value, dict):
+ return {
+ k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()
+ }
+ elif isinstance(value, Iterable):
+ return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg]
+ else:
+ return value
+
+ # Copy state assigned to params (and cast tensors to appropriate types).
+ # State that is not assigned to params is copied as is (needed for
+ # backward compatibility).
+ state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
+ for k, v in state_dict["state"].items():
+ if k in id_map:
+ param = id_map[k]
+ state[param] = _cast(param, v, param_id=k, param_groups=state_dict["param_groups"])
+ else:
+ state[k] = v
+
+ # Update parameter groups, setting their 'params' value
+ def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]:
+ new_group["params"] = group["params"]
+ return new_group
+
+ param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+ self.__setstate__({"state": state, "param_groups": param_groups})
+
+ for post_hook in self._optimizer_load_state_dict_post_hooks.values():
+ post_hook(self)
+
+ @staticmethod
+ def _process_value_according_to_param_policy(
+ param: torch.Tensor,
+ value: torch.Tensor,
+ param_id: int,
+ param_groups: List[Dict[Any, Any]],
+ key: Hashable = None,
+ ) -> torch.Tensor:
+ # Floating-point types are a bit special here. They are the only ones
+ # that are assumed to always match the type of params.
+ # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
+ # UNLESS fused or capturable, see note [special device hosting for step]
+ fused = False
+ capturable = False
+ assert param_groups is not None
+ for pg in param_groups:
+ if param_id in pg["params"]:
+ fused = pg["fused"] if "fused" in pg else False
+ capturable = pg["capturable"] if "capturable" in pg else False
+ break
+ if key == "step":
+ if capturable or fused:
+ return value.to(dtype=torch.float32, device=param.device)
+ else:
+ return value
+ else:
+ assert value.dtype in [torch.float8_e4m3fn, torch.float8_e5m2, torch.float32]
+ return value.to(device=param.device) # do not cast optimizer states
+ # if param.is_floating_point():
+ # return value.to(dtype=param.dtype, device=param.device)
+ # else:
+ # return value.to(device=param.device)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Perform a single optimization step.
+
+ Args:
+ closure (Callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ self._cuda_graph_capture_health_check()
+
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ params_with_grad = []
+ grads = []
+ exp_avgs = []
+ scale_exp_avgs = []
+ expand_exp_avgs = []
+ sqrt_minmax_exp_avgs = []
+ exp_avg_sqs = []
+ scale_exp_avg_sqs = []
+ expand_exp_avg_sqs = []
+ sqrt_minmax_exp_avg_sqs = []
+ max_exp_avg_sqs = []
+ state_steps = []
+ amsgrad = group["amsgrad"]
+ use_expansion = self.qargs.first_order_expansion in ["expansion", "true"]
+ beta1, beta2 = group["betas"]
+
+ self._init_group(
+ group,
+ params_with_grad,
+ grads,
+ amsgrad,
+ use_expansion,
+ exp_avgs,
+ scale_exp_avgs,
+ expand_exp_avgs,
+ sqrt_minmax_exp_avgs,
+ exp_avg_sqs,
+ scale_exp_avg_sqs,
+ expand_exp_avg_sqs,
+ sqrt_minmax_exp_avg_sqs,
+ max_exp_avg_sqs,
+ state_steps,
+ )
+
+ Coatadamw(
+ self.qargs,
+ params_with_grad,
+ grads,
+ exp_avgs,
+ scale_exp_avgs,
+ expand_exp_avgs,
+ sqrt_minmax_exp_avgs,
+ exp_avg_sqs,
+ scale_exp_avg_sqs,
+ expand_exp_avg_sqs,
+ sqrt_minmax_exp_avg_sqs,
+ max_exp_avg_sqs,
+ state_steps,
+ amsgrad=amsgrad,
+ use_expansion=use_expansion,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group["lr"],
+ weight_decay=group["weight_decay"],
+ eps=group["eps"],
+ qgroup_size=self.qargs.qgroup_size,
+ expand_min=self.qargs.expand_min,
+ fused=group["fused"],
+ grad_scale=getattr(self, "grad_scale", None),
+ found_inf=getattr(self, "found_inf", None),
+ )
+
+ return loss
+
+
+def Coatadamw(
+ qargs,
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ scale_exp_avgs: List[Tensor],
+ expand_exp_avgs: List[Tensor],
+ sqrt_minmax_exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ scale_exp_avg_sqs: List[Tensor],
+ expand_exp_avg_sqs: List[Tensor],
+ sqrt_minmax_exp_avg_sqs: List[Tensor],
+ max_exp_avg_sqs: List[Tensor],
+ state_steps: List[Tensor],
+ # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
+ # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
+ fused: Optional[bool] = None,
+ grad_scale: Optional[Tensor] = None,
+ found_inf: Optional[Tensor] = None,
+ *,
+ amsgrad: bool,
+ use_expansion: bool,
+ beta1: float,
+ beta2: float,
+ lr: Union[float, Tensor],
+ weight_decay: float,
+ eps: float,
+ qgroup_size: int,
+ expand_min: int,
+):
+ r"""Functional API that performs AdamW algorithm computation.
+
+ See :class:`~torch.optim.AdamW` for details.
+ """
+ if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
+ raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
+
+ func = _single_tensor_Coatadamw
+
+ func(
+ qargs,
+ params,
+ grads,
+ exp_avgs,
+ scale_exp_avgs,
+ expand_exp_avgs,
+ sqrt_minmax_exp_avgs,
+ exp_avg_sqs,
+ scale_exp_avg_sqs,
+ expand_exp_avg_sqs,
+ sqrt_minmax_exp_avg_sqs,
+ max_exp_avg_sqs,
+ state_steps,
+ amsgrad=amsgrad,
+ use_expansion=use_expansion,
+ beta1=beta1,
+ beta2=beta2,
+ lr=lr,
+ weight_decay=weight_decay,
+ eps=eps,
+ qgroup_size=qgroup_size,
+ expand_min=expand_min,
+ grad_scale=grad_scale,
+ found_inf=found_inf,
+ )
+
+
+def _dispatch_sqrt(x: float): # float annotation is needed because of torchscript type inference
+ if not torch.jit.is_scripting() and isinstance(x, torch.Tensor):
+ return x.sqrt()
+ else:
+ return sqrt(x)
+
+
+def _single_tensor_Coatadamw(
+ qargs,
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ scale_exp_avgs: List[Tensor],
+ expand_exp_avgs: List[Tensor],
+ sqrt_minmax_exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ scale_exp_avg_sqs: List[Tensor],
+ expand_exp_avg_sqs: List[Tensor],
+ sqrt_minmax_exp_avg_sqs: List[Tensor],
+ max_exp_avg_sqs: List[Tensor],
+ state_steps: List[Tensor],
+ grad_scale: Optional[Tensor],
+ found_inf: Optional[Tensor],
+ *,
+ amsgrad: bool,
+ use_expansion: bool,
+ beta1: float,
+ beta2: float,
+ lr: Union[Tensor, float],
+ weight_decay: float,
+ eps: float,
+ qgroup_size: int,
+ expand_min: int,
+):
+
+ assert grad_scale is None and found_inf is None
+
+ if torch.jit.is_scripting():
+ # this assert is due to JIT being dumb and not realizing that the ops below
+ # have overloads to handle both float and Tensor lrs, so we just assert it's
+ # a float since most people using JIT are using floats
+ assert isinstance(lr, float)
+
+ for i, param in enumerate(params):
+ grad = grads[i]
+ # First order
+ exp_avg = exp_avgs[i]
+ scale_exp_avg = scale_exp_avgs[i]
+ # Second order
+ exp_avg_sq = exp_avg_sqs[i]
+ scale_exp_avg_sq = scale_exp_avg_sqs[i]
+ step_t = state_steps[i]
+
+ # print(len(exp_avg.unique()), len(exp_avg_sq.unique()))
+ # print(f"{param.shape}, {grad.shape}, {exp_avg.shape}, {exp_avg_sq.shape}", file=open('debug.txt', 'a'))
+
+ # update step
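+ # step is converted to a plain Python int below because the qoptim_cuda
+ # kernels take an integer step for the bias-correction terms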
+ step_t += 1
+ step = int(step_t.item())
+
+ # Perform Optimizer Step
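+ # use_expansion enables COAT's dynamic range expansion: optimizer states are
+ # raised to a per-group power before FP8 (E4M3) quantization, and the extra
+ # expand_* / sqrt_minmax_* tensors hold what is needed to invert that mapping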
+ if use_expansion:
+ expand_exp_avg = expand_exp_avgs[i]
+ sqrt_minmax_exp_avg = sqrt_minmax_exp_avgs[i]
+ expand_exp_avg_sq = expand_exp_avg_sqs[i]
+ sqrt_minmax_exp_avg_sq = sqrt_minmax_exp_avg_sqs[i]
+
+ qoptim_cuda.fp8_adamw_expand_step(
+ param,
+ grad,
+ exp_avg,
+ scale_exp_avg,
+ expand_exp_avg,
+ sqrt_minmax_exp_avg,
+ exp_avg_sq,
+ scale_exp_avg_sq,
+ expand_exp_avg_sq,
+ sqrt_minmax_exp_avg_sq,
+ beta1,
+ beta2,
+ lr,
+ weight_decay,
+ eps,
+ step,
+ qgroup_size,
+ expand_min,
+ )
+
+ else:
+ qoptim_cuda.fp8_adamw_step(
+ param,
+ grad,
+ exp_avg,
+ scale_exp_avg,
+ exp_avg_sq,
+ scale_exp_avg_sq,
+ beta1,
+ beta2,
+ lr,
+ weight_decay,
+ eps,
+ step,
+ qgroup_size,
+ )
diff --git a/llava/model/coat/optimizer/kernels/bindings.cpp b/llava/model/coat/optimizer/kernels/bindings.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..aa3dbf8b145c263000ce3b7151bb73714c55f655
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/bindings.cpp
@@ -0,0 +1,10 @@
+#include <torch/extension.h>
+
+#include "include/fp8_adamw.h"
+#include "include/fp8_adamw_expand.h"
+
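+// Exposes the FP8 AdamW update kernels to Python; setup.py builds this
+// extension under the module name "qoptim_cuda".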
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("fp8_adamw_step", &FP8_AdamW, "Update the quantized optimizer states");
+ m.def("fp8_adamw_expand_step", &FP8_AdamW_expand,
+ "Update the quantized optimizer states, use polynomial expander");
+}
diff --git a/llava/model/coat/optimizer/kernels/build/lib.linux-x86_64-cpython-310/qoptim_cuda.cpython-310-x86_64-linux-gnu.so b/llava/model/coat/optimizer/kernels/build/lib.linux-x86_64-cpython-310/qoptim_cuda.cpython-310-x86_64-linux-gnu.so
new file mode 100644
index 0000000000000000000000000000000000000000..0d6753b8a2e35f20ec2070c6bce2c9e4acefa0d7
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/build/lib.linux-x86_64-cpython-310/qoptim_cuda.cpython-310-x86_64-linux-gnu.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e40fca6d032a0a3094e50a4fdb04e78998819de2652f0a82e49f49225dd46ba9
+size 235960
diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/bindings.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/bindings.o
new file mode 100644
index 0000000000000000000000000000000000000000..be1ed8ce2592d7ce94e249fd52af112a39de1af2
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/bindings.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:daebebfbb9d1a3dbb5e22acb50af317aa74b437f7aa965bd959b19010c5fd144
+size 244976
diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda.o
new file mode 100644
index 0000000000000000000000000000000000000000..837c2b25b4a1759b529b7e5bf2f74a781a60b73c
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac9c6832be5dccfce3f25ae735c67459afffd12f8f747b5018b3121da3dfd2d7
+size 215440
diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda_kernel.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda_kernel.o
new file mode 100644
index 0000000000000000000000000000000000000000..b6c5c480b055b909abffc646f69cd563c615cc47
Binary files /dev/null and b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_cuda_kernel.o differ
diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda.o
new file mode 100644
index 0000000000000000000000000000000000000000..ec6051ec3846f04b0f2c0f3bbd6abce6a626c3b5
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:799978e9ba5a24e75b58d0e218cf522874bb983d15fee15adb02d9ac0f2d7a10
+size 216240
diff --git a/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda_kernel.o b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda_kernel.o
new file mode 100644
index 0000000000000000000000000000000000000000..b29913e078a46679c6aedca63accff5ace9ca5b6
Binary files /dev/null and b/llava/model/coat/optimizer/kernels/build/temp.linux-x86_64-cpython-310/fp8_adamw_expand_cuda_kernel.o differ
diff --git a/llava/model/coat/optimizer/kernels/csrc_expand_quantize/makefile b/llava/model/coat/optimizer/kernels/csrc_expand_quantize/makefile
new file mode 100644
index 0000000000000000000000000000000000000000..ad4632ffe641f2043c2cbc4966dff66af2807c26
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/csrc_expand_quantize/makefile
@@ -0,0 +1,6 @@
+all:
+ nvcc nvcc_qoptim.cu -o nvcc_qoptim -gencode=arch=compute_90,code=compute_90
+run:
+ ./nvcc_qoptim
+clean:
+ rm -f nvcc_qoptim
diff --git a/llava/model/coat/optimizer/kernels/csrc_expand_quantize/nvcc_qoptim.cu b/llava/model/coat/optimizer/kernels/csrc_expand_quantize/nvcc_qoptim.cu
new file mode 100644
index 0000000000000000000000000000000000000000..79f370e76aaf3b99fc53c7fe053731dc9358729f
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/csrc_expand_quantize/nvcc_qoptim.cu
@@ -0,0 +1,478 @@
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <fstream>
+#include <iomanip>
+#include <iostream>
+#include <string>
+#include <type_traits>
+
+#include <cooperative_groups.h>
+#include <cuda.h>
+#include <cuda_fp8.h>
+#include <cuda_runtime.h>
+
+namespace cg = cooperative_groups;
+#define WARPSIZE 32
+#define QGROUPSIZE 128
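+// Small offset added to the per-group abs-max/min statistics so the derived
+// quantization scales can never be exactly zero.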
+#define QUANT_MIN_VAL 1e-20
+
+template <typename T>
+inline float fp8_dtype_max(const T &variable) {
+ if (std::is_same<T, __nv_fp8_e4m3>::value) {
+ return 448;
+ } else if (std::is_same<T, __nv_fp8_e5m2>::value) {
+ return 57344;
+ } else {
+ throw "Unsupported data format";
+ }
+}
+
+typedef enum { fp8_adamw } myCsrcKernels;
+
+void fp8_adamw_cpu(float *params, float *grads, float *fp_exp_avg,
+ float *fp_exp_avg_sq, float beta1, float beta2, float lr,
+ float wd, float eps, int step, int qgroup_size, int M,
+ int N) {
+ for (int idx = 0; idx < M * N; idx++) {
+ fp_exp_avg[idx] = beta1 * fp_exp_avg[idx] + (1 - beta1) * grads[idx];
+ fp_exp_avg_sq[idx] =
+ beta2 * fp_exp_avg_sq[idx] + (1 - beta2) * grads[idx] * grads[idx];
+
+ const float correction1 = 1.0f - powf(beta1, step);
+ const float correction2_sqrt = sqrtf(1.0f - powf(beta2, step));
+
+ float denom =
+ (sqrtf(fp_exp_avg_sq[idx]) / correction2_sqrt + eps) * correction1;
+ float update = (fp_exp_avg[idx] / denom) + (wd * params[idx]);
+ params[idx] = params[idx] - (lr * update);
+ }
+}
+
+template <typename T>
+void printFloatArrayToFile(T *array, int M, int N,
+ const std::string &outputFileName) {
+ std::ofstream outputFile(outputFileName);
+ if (!outputFile.is_open()) {
+ std::cout << "Failed to open the file." << std::endl;
+ return;
+ }
+
+ for (int i = 0; i < M; i++) {
+ for (int j = 0; j < N; j++) {
+ int index = i * N + j;
+ outputFile << std::setw(10) << std::right << std::fixed
+ << std::setprecision(6) << (float)array[index] << " ";
+ if (j == N - 1) {
+ outputFile << "\n";
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void fp8_adamw_csrc(
+ scalar_t *__restrict__ params, scalar_t *__restrict__ grads,
+ __nv_fp8_e4m3 *__restrict__ exp_avg, float *__restrict__ scale_exp_avg,
+ float *__restrict__ expand_exp_avg, float *__restrict__ sqrtminmax_exp_avg,
+ __nv_fp8_e4m3 *__restrict__ exp_avg_sq,
+ float *__restrict__ scale_exp_avg_sq, float *__restrict__ expand_exp_avg_sq,
+ float *__restrict__ sqrtminmax_exp_avg_sq, float beta1, float beta2,
+ float lr, float wd, float eps, int step, int qgroup_size, int expand_min,
+ int total_elements, int total_scale_elements) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int scale_idx = blockIdx.x;
+
+ float float_exp_avg, float_exp_avg_sq;
+ float correction1, correction2_sqrt;
+ float denom, update;
+
+ if (idx < total_elements) {
+ // dequantize the optimizer states
+ float_exp_avg = float(exp_avg[idx]) * scale_exp_avg[scale_idx];
+ int sign_exp_avg = 1 - 2 * signbit(float_exp_avg);
+ float_exp_avg = sign_exp_avg *
+ powf(fabsf(float_exp_avg), 1 / expand_exp_avg[scale_idx]) *
+ sqrtminmax_exp_avg[scale_idx];
+ float_exp_avg_sq = float(exp_avg_sq[idx]) * scale_exp_avg_sq[scale_idx];
+ float_exp_avg_sq =
+ powf(float_exp_avg_sq, 1 / expand_exp_avg_sq[scale_idx]) *
+ sqrtminmax_exp_avg_sq[scale_idx];
+
+ // calculation of optimizer.step()
+ float_exp_avg = beta1 * float_exp_avg + (1 - beta1) * grads[idx];
+ float_exp_avg_sq =
+ beta2 * float_exp_avg_sq + (1 - beta2) * grads[idx] * grads[idx];
+
+ correction1 = 1.0f - powf(beta1, step);
+ correction2_sqrt = sqrtf(1.0f - powf(beta2, step));
+
+ denom = (sqrtf(float_exp_avg_sq) / correction2_sqrt + eps) * correction1;
+ update = (float_exp_avg / denom) + (wd * params[idx]);
+ params[idx] = params[idx] - (lr * update);
+ } else {
+ float_exp_avg = 0.0f;
+ float_exp_avg_sq = 0.0f;
+ }
+
+ //// quantize the first-order and second-order momentum
+ int wid = threadIdx.x / WARPSIZE;
+
+ // reduction within a warp
+ __shared__ float sharedFirstMaxVal[32];
+ __shared__ float sharedFirstMinVal[32];
+ __shared__ float sharedSecondMaxVal[32];
+ __shared__ float sharedSecondMinVal[32];
+ cg::thread_block_tile<32> warpTile =
+ cg::tiled_partition<32>(cg::this_thread_block());
+ float firstMaxVal = fabsf(float_exp_avg);
+ float firstMinVal = fabsf(float_exp_avg);
+ float secondMaxVal = fabsf(float_exp_avg_sq);
+ float secondMinVal = fabsf(float_exp_avg_sq);
+ // Special handling: out-of-range threads must not affect the min reduction
+ if (idx >= total_elements) {
+ firstMinVal = __int_as_float(0x7f7fffff);
+ secondMinVal = __int_as_float(0x7f7fffff);
+ }
+
+ for (int i = warpTile.size() / 2; i > 0; i /= 2) {
+ float reduceFirstMaxVal = warpTile.shfl_down(firstMaxVal, i);
+ float reduceFirstMinVal = warpTile.shfl_down(firstMinVal, i);
+ float reduceSecondMaxVal = warpTile.shfl_down(secondMaxVal, i);
+ float reduceSecondMinVal = warpTile.shfl_down(secondMinVal, i);
+ firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal));
+ firstMinVal = fmin(firstMinVal, fabsf(reduceFirstMinVal));
+ secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal));
+ secondMinVal = fmin(secondMinVal, fabsf(reduceSecondMinVal));
+ // printf("First Max: %f\n", reduceFirstMaxVal);
+ }
+ int lane = warpTile.thread_rank();
+ if (lane == 0) {
+ sharedFirstMaxVal[wid] = firstMaxVal;
+ sharedFirstMinVal[wid] = firstMinVal;
+ sharedSecondMaxVal[wid] = secondMaxVal;
+ sharedSecondMinVal[wid] = secondMinVal;
+ }
+
+ __syncthreads();
+
+ // reduction within a block
+ __shared__ float shared_absmax_exp_avg;
+ __shared__ float shared_absmin_exp_avg;
+ __shared__ float shared_absmax_exp_avg_sq;
+ __shared__ float shared_absmin_exp_avg_sq;
+ firstMaxVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMaxVal[lane] : 0;
+ firstMinVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMinVal[lane] : 1e9;
+ secondMaxVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMaxVal[lane] : 0;
+ secondMinVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMinVal[lane] : 1e9;
+ if (wid == 0) {
+ for (int offset = WARPSIZE / 2; offset > 0; offset /= 2) {
+ float reduceFirstMaxVal =
+ __shfl_down_sync(0xFFFFFFFF, firstMaxVal, offset);
+ float reduceFirstMinVal =
+ __shfl_down_sync(0xFFFFFFFF, firstMinVal, offset);
+ float reduceSecondMaxVal =
+ __shfl_down_sync(0xFFFFFFFF, secondMaxVal, offset);
+ float reduceSecondMinVal =
+ __shfl_down_sync(0xFFFFFFFF, secondMinVal, offset);
+ firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal));
+ firstMinVal = fmin(firstMinVal, fabsf(reduceFirstMinVal));
+ secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal));
+ secondMinVal = fmin(secondMinVal, fabsf(reduceSecondMinVal));
+ }
+ if (lane == 0) {
+ shared_absmax_exp_avg = firstMaxVal;
+ shared_absmin_exp_avg = firstMinVal;
+ shared_absmax_exp_avg_sq = secondMaxVal;
+ shared_absmin_exp_avg_sq = secondMinVal;
+ }
+ }
+
+ __syncthreads();
+
+ if (idx < total_elements) {
+ // float fp8MaxVal = fp8_dtype_max<__nv_fp8_e4m3>(exp_avg[idx]);
+ // scaling factor before expanding
+ float fp8MaxVal = 448;
+
+ // dynamic exponent quantization part
+ firstMaxVal = shared_absmax_exp_avg + QUANT_MIN_VAL;
+ firstMinVal = shared_absmin_exp_avg + QUANT_MIN_VAL;
+ secondMaxVal = shared_absmax_exp_avg_sq + QUANT_MIN_VAL;
+ secondMinVal = shared_absmin_exp_avg_sq + QUANT_MIN_VAL;
+
+ // calculate the ratio and make the scale to center
+ float firstRatio = firstMaxVal / firstMinVal;
+ float secondRatio = secondMaxVal / secondMinVal;
+ float firstSqrtMinMax = sqrt(firstMaxVal * firstMinVal);
+ float secondSqrtMinMax = sqrt(secondMaxVal * secondMinVal);
+
+ // printf("Max %f, Min %f, Origin %f \n", firstMaxVal, firstMinVal,
+ // float_exp_avg);
+
+ // since we use x^k expander, calculate the optimal expanding factor
+ float ratioUpperBound = fp8MaxVal * fp8MaxVal / 2;
+ float firstExp =
+ floor((log2f(ratioUpperBound) / log2f(firstRatio)) * expand_min) /
+ expand_min; // if expand_min is 8, for example, then firstExp is a
+ // multiple of 1/8
+ float secondExp =
+ floor((log2f(ratioUpperBound) / log2f(secondRatio)) * expand_min) /
+ expand_min;
+
+ int sign_exp_avg = 1 - 2 * signbit(float_exp_avg);
+ float_exp_avg =
+ sign_exp_avg * powf(fabsf(float_exp_avg) / firstSqrtMinMax, firstExp);
+ float_exp_avg_sq = powf(float_exp_avg_sq / secondSqrtMinMax, secondExp);
+
+ // correspondingly, change the scaling factor
+ float new_scale_exp_avg =
+ powf(firstMaxVal / firstSqrtMinMax, firstExp) / fp8MaxVal;
+ float new_scale_exp_avg_sq =
+ powf(secondMaxVal / secondSqrtMinMax, secondExp) / fp8MaxVal;
+
+ // quantize the optimizer states
+ __nv_fp8_e4m3 exp_avg_new =
+ static_cast<__nv_fp8_e4m3>(float_exp_avg / new_scale_exp_avg);
+ __nv_fp8_e4m3 exp_avg_sq_new =
+ static_cast<__nv_fp8_e4m3>(float_exp_avg_sq / new_scale_exp_avg_sq);
+ // __half exp_avg_new = static_cast<__half>(float_exp_avg /
+ // new_scale_exp_avg);
+ // __half exp_avg_sq_new = static_cast<__half>(float_exp_avg_sq /
+ // new_scale_exp_avg_sq);
+
+ // printf("idx: %d, float: %f, quantize: %f\n", idx, float_exp_avg,
+ // (float)exp_avg_new * new_scale_exp_avg);
+
+ // store the output
+ exp_avg[idx] = exp_avg_new;
+ exp_avg_sq[idx] = exp_avg_sq_new;
+ scale_exp_avg[scale_idx] = new_scale_exp_avg;
+ scale_exp_avg_sq[scale_idx] = new_scale_exp_avg_sq;
+ expand_exp_avg[scale_idx] = firstExp;
+ expand_exp_avg_sq[scale_idx] = secondExp;
+ sqrtminmax_exp_avg[scale_idx] = firstSqrtMinMax;
+ sqrtminmax_exp_avg_sq[scale_idx] = secondSqrtMinMax;
+ }
+}
+
+template <myCsrcKernels algo>
+void myKernelLauncher(float *params, float *grads, __nv_fp8_e4m3 *exp_avg,
+ float *scale_exp_avg, float *expand_exp_avg,
+ float *sqrtminmax_exp_avg, __nv_fp8_e4m3 *exp_avg_sq,
+ float *scale_exp_avg_sq, float *expand_exp_avg_sq,
+ float *sqrtminmax_exp_avg_sq, float beta1, float beta2,
+ float lr, float wd, float eps, int step, int qgroup_size,
+ int expand_min, int M, int N) {
+ if (algo == fp8_adamw) {
+ const int block_dim = 128;
+ int grid_dim = (M * N + qgroup_size - 1) / block_dim;
+ const dim3 gridDim(grid_dim);
+ const dim3 blockDim(block_dim);
+ printf("Yes!\n");
+ fp8_adamw_csrc<<<gridDim, blockDim>>>(
+ params, grads, exp_avg, scale_exp_avg, expand_exp_avg,
+ sqrtminmax_exp_avg, exp_avg_sq, scale_exp_avg_sq, expand_exp_avg_sq,
+ sqrtminmax_exp_avg_sq, beta1, beta2, lr, wd, eps, step, qgroup_size,
+ expand_min, M * N, int(floor(M * N / 128.)));
+ cudaError_t error = cudaGetLastError();
+ if (error != cudaSuccess) {
+ std::cout << "CUDA error occurred in kernel launch: "
+ << cudaGetErrorString(error) << std::endl;
+ return;
+ }
+ printf("Finish!\n");
+ }
+}
+
+float testMaxError(void (*myGPUKernel)(float *, float *, __nv_fp8_e4m3 *,
+ float *, float *, float *,
+ __nv_fp8_e4m3 *, float *, float *,
+ float *, // tensor input
+ float, float, float, float, float, int,
+ int, int, // float and int input
+ int, int), // M and N
+ int M, int N) {
+ size_t size_param = M * N * sizeof(float);
+ size_t size_optim = M * N * sizeof(__nv_fp8_e4m3);
+ size_t size_scale = int(ceil(M * N / 128.)) * sizeof(float);
+
+ // host tensor
+ float *h_p, *h_g;
+ __nv_fp8_e4m3 *h_m, *h_v;
+ float *h_sm, *h_sv;
+ float *h_fp_m, *h_fp_v;
+ float *h_cpd_m, *h_cpd_v;
+ float *h_sqrtmm_m, *h_sqrtmm_v;
+
+ // device tensor
+ float *d_p, *d_g;
+ __nv_fp8_e4m3 *d_m, *d_v;
+ float *d_sm, *d_sv;
+ float *d_cpd_m, *d_cpd_v;
+ float *d_sqrtmm_m, *d_sqrtmm_v;
+
+ // device tensor transfer to host
+ float *hd_p, *hd_g;
+ __nv_fp8_e4m3 *hd_m, *hd_v;
+ float *hd_sm, *hd_sv;
+ float *hd_fp_m, *hd_fp_v;
+ float *hd_cpd_m, *hd_cpd_v;
+ float *hd_sqrtmm_m, *hd_sqrtmm_v;
+
+ h_p = (float *)malloc(size_param);
+ h_g = (float *)malloc(size_param);
+ h_m = (__nv_fp8_e4m3 *)malloc(size_optim);
+ h_v = (__nv_fp8_e4m3 *)malloc(size_optim);
+ h_sm = (float *)malloc(size_scale);
+ h_sv = (float *)malloc(size_scale);
+ h_cpd_m = (float *)malloc(size_scale);
+ h_cpd_v = (float *)malloc(size_scale);
+ h_sqrtmm_m = (float *)malloc(size_scale);
+ h_sqrtmm_v = (float *)malloc(size_scale);
+ h_fp_m = (float *)malloc(size_param);
+ h_fp_v = (float *)malloc(size_param);
+ cudaMalloc(&d_p, size_param);
+ cudaMalloc(&d_g, size_param);
+ cudaMalloc(&d_m, size_optim);
+ cudaMalloc(&d_v, size_optim);
+ cudaMalloc(&d_sm, size_scale);
+ cudaMalloc(&d_sv, size_scale);
+ cudaMalloc(&d_cpd_m, size_scale);
+ cudaMalloc(&d_cpd_v, size_scale);
+ cudaMalloc(&d_sqrtmm_m, size_scale);
+ cudaMalloc(&d_sqrtmm_v, size_scale);
+ hd_p = (float *)malloc(size_param);
+ hd_g = (float *)malloc(size_param);
+ hd_m = (__nv_fp8_e4m3 *)malloc(size_optim);
+ hd_v = (__nv_fp8_e4m3 *)malloc(size_optim);
+ hd_sm = (float *)malloc(size_scale);
+ hd_sv = (float *)malloc(size_scale);
+ hd_fp_m = (float *)malloc(size_param);
+ hd_fp_v = (float *)malloc(size_param);
+ hd_cpd_m = (float *)malloc(size_scale);
+ hd_cpd_v = (float *)malloc(size_scale);
+ hd_sqrtmm_m = (float *)malloc(size_scale);
+ hd_sqrtmm_v = (float *)malloc(size_scale);
+
+ cudaError_t error = cudaGetLastError();
+ if (error != cudaSuccess) {
+ std::cout << "CUDA error occurred in data copy: "
+ << cudaGetErrorString(error) << std::endl;
+ return 0.;
+ }
+
+ srand(0);
+ // random initialization for CPU tensor
+ for (int i = 0; i < M * N; i++) {
+ h_p[i] = (float)(rand() / (float(RAND_MAX) / 10));
+ h_g[i] = (float)(rand() / (float(RAND_MAX) / 10));
+ h_m[i] = (__nv_fp8_e4m3)(rand() / (float(RAND_MAX) / 10));
+ h_v[i] = (__nv_fp8_e4m3)(rand() / (float(RAND_MAX) / 10));
+ }
+ for (int i = 0; i < int(ceilf(M * N / 128.)); i++) {
+ h_sm[i] = (float)(rand() / (float(RAND_MAX) / 10));
+ h_sv[i] = (float)(rand() / (float(RAND_MAX) / 10));
+ h_cpd_m[i] = 2;
+ h_cpd_v[i] = 3. / 8.;
+ h_sqrtmm_m[i] = (float)(rand() / (float(RAND_MAX) / 10));
+ h_sqrtmm_v[i] = (float)(rand() / (float(RAND_MAX) / 10));
+
+ printf("scale is %f\n", h_sm[i]);
+ }
+ for (int i = 0; i < M * N; i++) {
+ h_fp_m[i] = (float)h_m[i] * h_sm[int(floor(i / 128.))];
+ h_fp_v[i] = (float)h_v[i] * h_sv[int(floor(i / 128.))];
+ }
+ float beta1 = 0.9, beta2 = 0.95, lr = 4e-4, wd = 0.1, eps = 1e-8;
+ int step = 100, qgroup_size = 128, expand_min = 16;
+
+ printFloatArrayToFile(h_p, M, N, "Past_CPU_param.txt");
+ printFloatArrayToFile(h_g, M, N, "Past_CPU_grad.txt");
+ printFloatArrayToFile(h_m, M, N, "Past_CPU_m1.txt");
+ printFloatArrayToFile(h_sm, 1, int(ceilf(M * N / 128.)), "Past_CPU_ms.txt");
+ printFloatArrayToFile(h_fp_m, M, N, "Past_CPU_mf.txt");
+ printFloatArrayToFile(h_v, M, N, "Past_CPU_v2.txt");
+ printFloatArrayToFile(h_sv, 1, int(ceilf(M * N / 128.)), "Past_CPU_vs.txt");
+ printFloatArrayToFile(h_fp_v, M, N, "Past_CPU_vf.txt");
+
+ cudaMemcpy(d_p, h_p, size_param, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_g, h_g, size_param, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_m, h_m, size_optim, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_v, h_v, size_optim, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_sm, h_sm, size_scale, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_sv, h_sv, size_scale, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_cpd_m, h_cpd_m, size_scale, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_cpd_v, h_cpd_v, size_scale, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_sqrtmm_m, h_sqrtmm_m, size_scale, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_sqrtmm_v, h_sqrtmm_v, size_scale, cudaMemcpyHostToDevice);
+
+ fp8_adamw_cpu(h_p, h_g, h_fp_m, h_fp_v, beta1, beta2, lr, wd, eps, step,
+ qgroup_size, M, N);
+
+ if (error != cudaSuccess) {
+ std::cout << "CUDA error occurred in data initialization: "
+ << cudaGetErrorString(error) << std::endl;
+ return 0.;
+ }
+
+ myGPUKernel(d_p, d_g, d_m, d_sm, d_cpd_m, d_sqrtmm_m, d_v, d_sv, d_cpd_v,
+ d_sqrtmm_v, beta1, beta2, lr, wd, eps, step, qgroup_size,
+ expand_min, M, N);
+
+ cudaMemcpy(hd_p, d_p, size_param, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_g, d_g, size_param, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_m, d_m, size_optim, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_v, d_v, size_optim, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_sm, d_sm, size_scale, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_sv, d_sv, size_scale, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_cpd_m, d_cpd_m, size_scale, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_cpd_v, d_cpd_v, size_scale, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_sqrtmm_m, d_sqrtmm_m, size_scale, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_sqrtmm_v, d_sqrtmm_v, size_scale, cudaMemcpyDeviceToHost);
+
+ for (int i = 0; i < M * N; i++) {
+ hd_fp_m[i] = pow((float)hd_m[i] * hd_sm[int(floor(i / 128.))],
+ 1 / hd_cpd_m[int(floor(i / 128.))]) *
+ hd_sqrtmm_m[int(floor(i / 128.))];
+ hd_fp_v[i] = pow((float)hd_v[i] * hd_sv[int(floor(i / 128.))],
+ 1 / hd_cpd_v[int(floor(i / 128.))]) *
+ hd_sqrtmm_v[int(floor(i / 128.))];
+ }
+ printFloatArrayToFile(h_p, M, N, "CPU_param.txt");
+ printFloatArrayToFile(hd_p, M, N, "GPU_param.txt");
+ printFloatArrayToFile(h_g, M, N, "CPU_grad.txt");
+ printFloatArrayToFile(hd_g, M, N, "GPU_grad.txt");
+ printFloatArrayToFile(h_m, M, N, "CPU_m1.txt");
+ printFloatArrayToFile(h_sm, 1, int(ceilf(M * N / 128.)), "CPU_ms.txt");
+ printFloatArrayToFile(h_fp_m, M, N, "CPU_mf.txt");
+ printFloatArrayToFile(hd_m, M, N, "GPU_m1.txt");
+ printFloatArrayToFile(hd_sm, 1, int(ceilf(M * N / 128.)), "GPU_ms.txt");
+ printFloatArrayToFile(hd_fp_m, M, N, "GPU_mf.txt");
+ printFloatArrayToFile(h_v, M, N, "CPU_v2.txt");
+ printFloatArrayToFile(h_sv, 1, int(ceilf(M * N / 128.)), "CPU_vs.txt");
+ printFloatArrayToFile(h_fp_v, M, N, "CPU_vf.txt");
+ printFloatArrayToFile(hd_v, M, N, "GPU_v2.txt");
+ printFloatArrayToFile(hd_sv, 1, int(ceilf(M * N / 128.)), "GPU_vs.txt");
+ printFloatArrayToFile(hd_fp_v, M, N, "GPU_vf.txt");
+
+ printFloatArrayToFile(hd_cpd_m, 1, int(ceilf(M * N / 128.)), "GPU_cpd_m.txt");
+ printFloatArrayToFile(hd_cpd_v, 1, int(ceilf(M * N / 128.)), "GPU_cpd_v.txt");
+ printFloatArrayToFile(hd_sqrtmm_m, 1, int(ceilf(M * N / 128.)),
+ "GPU_sqrtmm_m.txt");
+ printFloatArrayToFile(hd_sqrtmm_v, 1, int(ceilf(M * N / 128.)),
+ "GPU_sqrtmm_v.txt");
+
+ return 0.;
+}
+
+int main() {
+ const int M = 1, N = 7;
+ float max_error = testMaxError(myKernelLauncher<fp8_adamw>, M, N);
+}
diff --git a/llava/model/coat/optimizer/kernels/csrc_origin_quantize/makefile b/llava/model/coat/optimizer/kernels/csrc_origin_quantize/makefile
new file mode 100644
index 0000000000000000000000000000000000000000..ad4632ffe641f2043c2cbc4966dff66af2807c26
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/csrc_origin_quantize/makefile
@@ -0,0 +1,6 @@
+all:
+ nvcc nvcc_qoptim.cu -o nvcc_qoptim -gencode=arch=compute_90,code=compute_90
+run:
+ ./nvcc_qoptim
+clean:
+ rm -f nvcc_qoptim
diff --git a/llava/model/coat/optimizer/kernels/csrc_origin_quantize/nvcc_qoptim.cu b/llava/model/coat/optimizer/kernels/csrc_origin_quantize/nvcc_qoptim.cu
new file mode 100644
index 0000000000000000000000000000000000000000..dde952654b71f11dc6de051a7089df38e14c6b57
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/csrc_origin_quantize/nvcc_qoptim.cu
@@ -0,0 +1,354 @@
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <fstream>
+#include <iomanip>
+#include <iostream>
+#include <string>
+#include <type_traits>
+
+#include <cooperative_groups.h>
+#include <cuda_fp8.h>
+#include <cuda_runtime.h>
+
+namespace cg = cooperative_groups;
+#define WARPSIZE 32
+#define QGROUPSIZE 128
+#define QUANT_MIN_VAL 1e-20
+
+template <typename T>
+inline float fp8_dtype_max(const T &variable) {
+ if (std::is_same<T, __nv_fp8_e4m3>::value) {
+ return 448;
+ } else if (std::is_same<T, __nv_fp8_e5m2>::value) {
+ return 57344;
+ } else {
+ throw "Unsupported data format";
+ }
+}
+
+typedef enum { fp8_adamw } myCsrcKernels;
+
+void fp8_adamw_cpu(float *params, float *grads, float *fp_exp_avg,
+ float *fp_exp_avg_sq, float beta1, float beta2, float lr,
+ float wd, float eps, int step, int qgroup_size, int M,
+ int N) {
+ for (int idx = 0; idx < M * N; idx++) {
+ fp_exp_avg[idx] = beta1 * fp_exp_avg[idx] + (1 - beta1) * grads[idx];
+ fp_exp_avg_sq[idx] =
+ beta2 * fp_exp_avg_sq[idx] + (1 - beta2) * grads[idx] * grads[idx];
+
+ const float correction1 = 1.0f - powf(beta1, step);
+ const float correction2_sqrt = sqrtf(1.0f - powf(beta2, step));
+
+ float denom =
+ (sqrtf(fp_exp_avg_sq[idx]) / correction2_sqrt + eps) * correction1;
+ float update = (fp_exp_avg[idx] / denom) + (wd * params[idx]);
+ params[idx] = params[idx] - (lr * update);
+ }
+}
+
+template <typename T>
+void printFloatArrayToFile(T *array, int M, int N,
+ const std::string &outputFileName) {
+ std::ofstream outputFile(outputFileName);
+ if (!outputFile.is_open()) {
+ std::cout << "Failed to open the file." << std::endl;
+ return;
+ }
+
+ for (int i = 0; i < M; i++) {
+ for (int j = 0; j < N; j++) {
+ int index = i * N + j;
+ outputFile << std::setw(10) << std::right << std::fixed
+ << std::setprecision(6) << (float)array[index] << " ";
+ if (j == N - 1) {
+ outputFile << "\n";
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void fp8_adamw_csrc(scalar_t *__restrict__ params,
+ scalar_t *__restrict__ grads,
+ __nv_fp8_e4m3 *__restrict__ exp_avg,
+ float *__restrict__ scale_exp_avg,
+ __nv_fp8_e4m3 *__restrict__ exp_avg_sq,
+ float *__restrict__ scale_exp_avg_sq,
+ float beta1, float beta2, float lr, float wd,
+ float eps, int step, int qgroup_size,
+ int total_elements, int total_scale_elements) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int scale_idx = blockIdx.x;
+
+ float float_exp_avg, float_exp_avg_sq;
+ float correction1, correction2_sqrt;
+ float denom, update;
+
+ if (idx < total_elements) {
+ // dequantize the optimizer states
+ float_exp_avg = float(exp_avg[idx]) * scale_exp_avg[scale_idx];
+ float_exp_avg_sq = float(exp_avg_sq[idx]) * scale_exp_avg_sq[scale_idx];
+
+ // calculation of optimizer.step()
+ float_exp_avg = beta1 * float_exp_avg + (1 - beta1) * grads[idx];
+ float_exp_avg_sq =
+ beta2 * float_exp_avg_sq + (1 - beta2) * grads[idx] * grads[idx];
+
+ correction1 = 1.0f - powf(beta1, step);
+ correction2_sqrt = sqrtf(1.0f - powf(beta2, step));
+
+ denom = (sqrtf(float_exp_avg_sq) / correction2_sqrt + eps) * correction1;
+ update = (float_exp_avg / denom) + (wd * params[idx]);
+
+ params[idx] = params[idx] - (lr * update);
+ } else {
+ float_exp_avg = 0.0f;
+ float_exp_avg_sq = 0.0f;
+ }
+
+ //// quantize the first-order and second-order momentum
+ int wid = threadIdx.x / WARPSIZE;
+
+ // reduction within a warp
+
+ __shared__ float sharedFirstMaxVal[32];
+ __shared__ float sharedSecondMaxVal[32];
+ cg::thread_block_tile<32> warpTile =
+ cg::tiled_partition<32>(cg::this_thread_block());
+ float firstMaxVal = fabsf(float_exp_avg);
+ float secondMaxVal = fabsf(float_exp_avg_sq);
+
+ for (int i = warpTile.size() / 2; i > 0; i /= 2) {
+ float reduceFirstMaxVal = warpTile.shfl_down(firstMaxVal, i);
+ float reduceSecondMaxVal = warpTile.shfl_down(secondMaxVal, i);
+ firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal));
+ secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal));
+ // printf("First Max: %f\n", reduceFirstMaxVal);
+ }
+ int lane = warpTile.thread_rank();
+ if (lane == 0) sharedFirstMaxVal[wid] = firstMaxVal;
+ if (lane == 0) sharedSecondMaxVal[wid] = secondMaxVal;
+
+ __syncthreads();
+
+ // reduction within a block
+ __shared__ float shared_absmax_exp_avg;
+ __shared__ float shared_absmax_exp_avg_sq;
+ firstMaxVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMaxVal[lane] : 0;
+ secondMaxVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMaxVal[lane] : 0;
+ if (wid == 0) {
+ for (int offset = WARPSIZE / 2; offset > 0; offset /= 2) {
+ float reduceFirstMaxVal =
+ __shfl_down_sync(0xFFFFFFFF, firstMaxVal, offset);
+ float reduceSecondMaxVal =
+ __shfl_down_sync(0xFFFFFFFF, secondMaxVal, offset);
+ firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal));
+ secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal));
+ }
+ if (lane == 0) shared_absmax_exp_avg = firstMaxVal;
+ if (lane == 0) shared_absmax_exp_avg_sq = secondMaxVal;
+ }
+
+ __syncthreads();
+
+ if (idx < total_elements) {
+ // float fp8MaxVal = fp8_dtype_max<__nv_fp8_e4m3>(exp_avg[idx]);
+ float fp8MaxVal = 448;
+
+ shared_absmax_exp_avg = shared_absmax_exp_avg + QUANT_MIN_VAL;
+ shared_absmax_exp_avg_sq = shared_absmax_exp_avg_sq + QUANT_MIN_VAL;
+
+ float new_scale_exp_avg = shared_absmax_exp_avg / fp8MaxVal;
+ float new_scale_exp_avg_sq = shared_absmax_exp_avg_sq / fp8MaxVal;
+
+ // quantize the optimizer states
+ __nv_fp8_e4m3 exp_avg_new =
+ static_cast<__nv_fp8_e4m3>(float_exp_avg / new_scale_exp_avg);
+ __nv_fp8_e4m3 exp_avg_sq_new =
+ static_cast<__nv_fp8_e4m3>(float_exp_avg_sq / new_scale_exp_avg_sq);
+ // __half exp_avg_new = static_cast<__half>(float_exp_avg /
+ // new_scale_exp_avg);
+ // __half exp_avg_sq_new = static_cast<__half>(float_exp_avg_sq /
+ // new_scale_exp_avg_sq);
+
+ // printf("idx: %d, float: %f, quantize: %f\n", idx, float_exp_avg,
+ // (float)exp_avg_new * new_scale_exp_avg);
+
+ // store the output
+ exp_avg[idx] = exp_avg_new;
+ exp_avg_sq[idx] = exp_avg_sq_new;
+ scale_exp_avg[scale_idx] = new_scale_exp_avg;
+ scale_exp_avg_sq[scale_idx] = new_scale_exp_avg_sq;
+ }
+}
+
+template <myCsrcKernels algo>
+void myKernelLauncher(float *params, float *grads, __nv_fp8_e4m3 *exp_avg,
+ float *scale_exp_avg, __nv_fp8_e4m3 *exp_avg_sq,
+ float *scale_exp_avg_sq, float beta1, float beta2,
+ float lr, float wd, float eps, int step, int qgroup_size,
+ int M, int N) {
+ if (algo == fp8_adamw) {
+ const int block_dim = 128;
+ int grid_dim = (M * N + qgroup_size - 1) / block_dim;
+ const dim3 gridDim(grid_dim);
+ const dim3 blockDim(block_dim);
+ printf("Yes!\n");
+ fp8_adamw_csrc<<<gridDim, blockDim>>>(
+ params, grads, exp_avg, scale_exp_avg, exp_avg_sq, scale_exp_avg_sq,
+ beta1, beta2, lr, wd, eps, step, qgroup_size, M * N,
+ int(floor(M * N / 128.)));
+ cudaError_t error = cudaGetLastError();
+ if (error != cudaSuccess) {
+ std::cout << "CUDA error occurred in kernel launch: "
+ << cudaGetErrorString(error) << std::endl;
+ return;
+ }
+ printf("Finish!\n");
+ }
+}
+
+float testMaxError(void (*myGPUKernel)(float *, float *, __nv_fp8_e4m3 *,
+ float *, __nv_fp8_e4m3 *, float *, float,
+ float, float, float, float, int, int,
+ int, int),
+ int M, int N) {
+ size_t size_param = M * N * sizeof(float);
+ size_t size_optim = M * N * sizeof(__nv_fp8_e4m3);
+ size_t size_scale = int(ceil(M * N / 128.)) * sizeof(float);
+
+ // host tensor
+ float *h_p, *h_g;
+ __nv_fp8_e4m3 *h_m, *h_v;
+ float *h_sm, *h_sv;
+ float *h_fp_m, *h_fp_v;
+
+ // device tensor
+ float *d_p, *d_g;
+ __nv_fp8_e4m3 *d_m, *d_v;
+ float *d_sm, *d_sv;
+
+ // device tensor transfer to host
+ float *hd_p, *hd_g;
+ __nv_fp8_e4m3 *hd_m, *hd_v;
+ float *hd_sm, *hd_sv;
+ float *hd_fp_m, *hd_fp_v;
+
+ h_p = (float *)malloc(size_param);
+ h_g = (float *)malloc(size_param);
+ h_m = (__nv_fp8_e4m3 *)malloc(size_optim);
+ h_v = (__nv_fp8_e4m3 *)malloc(size_optim);
+ h_sm = (float *)malloc(size_scale);
+ h_sv = (float *)malloc(size_scale);
+ h_fp_m = (float *)malloc(size_param);
+ h_fp_v = (float *)malloc(size_param);
+ cudaMalloc(&d_p, size_param);
+ cudaMalloc(&d_g, size_param);
+ cudaMalloc(&d_m, size_optim);
+ cudaMalloc(&d_v, size_optim);
+ cudaMalloc(&d_sm, size_scale);
+ cudaMalloc(&d_sv, size_scale);
+ hd_p = (float *)malloc(size_param);
+ hd_g = (float *)malloc(size_param);
+ hd_m = (__nv_fp8_e4m3 *)malloc(size_optim);
+ hd_v = (__nv_fp8_e4m3 *)malloc(size_optim);
+ hd_sm = (float *)malloc(size_scale);
+ hd_sv = (float *)malloc(size_scale);
+ hd_fp_m = (float *)malloc(size_param);
+ hd_fp_v = (float *)malloc(size_param);
+
+ cudaError_t error = cudaGetLastError();
+ if (error != cudaSuccess) {
+ std::cout << "CUDA error occurred in data copy: "
+ << cudaGetErrorString(error) << std::endl;
+ return 0.;
+ }
+
+ srand(0);
+ // random initialization for CPU tensor
+ for (int i = 0; i < M * N; i++) {
+ h_p[i] = (float)(rand() / (float(RAND_MAX) / 10));
+ h_g[i] = (float)(rand() / (float(RAND_MAX) / 10));
+ h_m[i] = (__nv_fp8_e4m3)(rand() / (float(RAND_MAX) / 10));
+ h_v[i] = (__nv_fp8_e4m3)(rand() / (float(RAND_MAX) / 10));
+ }
+ for (int i = 0; i < int(ceilf(M * N / 128.)); i++) {
+ h_sm[i] = (float)(rand() / (float(RAND_MAX) / 10));
+ h_sv[i] = (float)(rand() / (float(RAND_MAX) / 10));
+ printf("scale is %f\n", h_sm[i]);
+ }
+ for (int i = 0; i < M * N; i++) {
+ h_fp_m[i] = (float)h_m[i] * h_sm[int(floor(i / 128.))];
+ h_fp_v[i] = (float)h_v[i] * h_sv[int(floor(i / 128.))];
+ }
+ float beta1 = 0.9, beta2 = 0.95, lr = 4e-4, wd = 0.1, eps = 1e-8;
+ int step = 100, qgroup_size = 128;
+
+ printFloatArrayToFile(h_p, M, N, "Past_CPU_param.txt");
+ printFloatArrayToFile(h_g, M, N, "Past_CPU_grad.txt");
+ printFloatArrayToFile(h_m, M, N, "Past_CPU_m1.txt");
+ printFloatArrayToFile(h_sm, 1, int(ceilf(M * N / 128.)), "Past_CPU_ms.txt");
+ printFloatArrayToFile(h_fp_m, M, N, "Past_CPU_mf.txt");
+ printFloatArrayToFile(h_v, M, N, "Past_CPU_v2.txt");
+ printFloatArrayToFile(h_sv, 1, int(ceilf(M * N / 128.)), "Past_CPU_vs.txt");
+ printFloatArrayToFile(h_fp_v, M, N, "Past_CPU_vf.txt");
+
+ cudaMemcpy(d_p, h_p, size_param, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_g, h_g, size_param, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_m, h_m, size_optim, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_v, h_v, size_optim, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_sm, h_sm, size_scale, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_sv, h_sv, size_scale, cudaMemcpyHostToDevice);
+
+ fp8_adamw_cpu(h_p, h_g, h_fp_m, h_fp_v, beta1, beta2, lr, wd, eps, step,
+ qgroup_size, M, N);
+
+ if (error != cudaSuccess) {
+ std::cout << "CUDA error occurred in data initialization: "
+ << cudaGetErrorString(error) << std::endl;
+ return 0.;
+ }
+
+ myGPUKernel(d_p, d_g, d_m, d_sm, d_v, d_sv, beta1, beta2, lr, wd, eps, step,
+ qgroup_size, M, N);
+
+ cudaMemcpy(hd_p, d_p, size_param, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_g, d_g, size_param, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_m, d_m, size_optim, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_v, d_v, size_optim, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_sm, d_sm, size_scale, cudaMemcpyDeviceToHost);
+ cudaMemcpy(hd_sv, d_sv, size_scale, cudaMemcpyDeviceToHost);
+
+ for (int i = 0; i < M * N; i++) {
+ hd_fp_m[i] = (float)hd_m[i] * hd_sm[int(floor(i / 128.))];
+ hd_fp_v[i] = (float)hd_v[i] * hd_sv[int(floor(i / 128.))];
+ }
+ printFloatArrayToFile(h_p, M, N, "CPU_param.txt");
+ printFloatArrayToFile(hd_p, M, N, "GPU_param.txt");
+ printFloatArrayToFile(h_g, M, N, "CPU_grad.txt");
+ printFloatArrayToFile(hd_g, M, N, "GPU_grad.txt");
+ printFloatArrayToFile(h_m, M, N, "CPU_m1.txt");
+ printFloatArrayToFile(h_sm, 1, int(ceilf(M * N / 128.)), "CPU_ms.txt");
+ printFloatArrayToFile(h_fp_m, M, N, "CPU_mf.txt");
+ printFloatArrayToFile(hd_m, M, N, "GPU_m1.txt");
+ printFloatArrayToFile(hd_sm, 1, int(ceilf(M * N / 128.)), "GPU_ms.txt");
+ printFloatArrayToFile(hd_fp_m, M, N, "GPU_mf.txt");
+ printFloatArrayToFile(h_v, M, N, "CPU_v2.txt");
+ printFloatArrayToFile(h_sv, 1, int(ceilf(M * N / 128.)), "CPU_vs.txt");
+ printFloatArrayToFile(h_fp_v, M, N, "CPU_vf.txt");
+ printFloatArrayToFile(hd_v, M, N, "GPU_v2.txt");
+ printFloatArrayToFile(hd_sv, 1, int(ceilf(M * N / 128.)), "GPU_vs.txt");
+ printFloatArrayToFile(hd_fp_v, M, N, "GPU_vf.txt");
+
+ return 0.;
+}
+
+int main() {
+ const int M = 1, N = 7;
+ float max_error = testMaxError(myKernelLauncher<fp8_adamw>, M, N);
+}
diff --git a/llava/model/coat/optimizer/kernels/fp8_adamw_cuda.cpp b/llava/model/coat/optimizer/kernels/fp8_adamw_cuda.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..57190c713309a19eafe35b4651e55d7c5bd576f7
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/fp8_adamw_cuda.cpp
@@ -0,0 +1,26 @@
+#include <torch/extension.h>
+#include <torch/types.h>
+
+void FP8_AdamW_cuda(torch::Tensor params, // parameter
+ torch::Tensor grads, // gradient
+ torch::Tensor exp_avg, // first order momentum
+ torch::Tensor scale_exp_avg,
+ torch::Tensor exp_avg_sq, // second order momentum
+ torch::Tensor scale_exp_avg_sq, float beta1, float beta2,
+ float lr, float wd, float eps, int step,
+ int qgroup_size // other parameters
+);
+
+void FP8_AdamW(torch::Tensor params, // parameter
+ torch::Tensor grads, // gradient
+ torch::Tensor exp_avg, // first order momentum
+ torch::Tensor scale_exp_avg,
+ torch::Tensor exp_avg_sq, // second order momentum
+ torch::Tensor scale_exp_avg_sq, float beta1, float beta2,
+ float lr, float wd, float eps, int step,
+ int qgroup_size) { // other parameters
+
+ FP8_AdamW_cuda(params, grads, exp_avg, scale_exp_avg, exp_avg_sq,
+ scale_exp_avg_sq, beta1, beta2, lr, wd, eps, step,
+ qgroup_size);
+}
diff --git a/llava/model/coat/optimizer/kernels/fp8_adamw_cuda_kernel.cu b/llava/model/coat/optimizer/kernels/fp8_adamw_cuda_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9e652845af5d228237d3c8df758df1a4fe94bcd4
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/fp8_adamw_cuda_kernel.cu
@@ -0,0 +1,163 @@
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/extension.h>
+#include <cmath>
+#include <cooperative_groups.h>
+#include <cuda.h>
+#include <cuda_fp8.h>
+#include <cuda_runtime.h>
+
+#define QUANT_MIN_VAL 1e-20
+
+namespace cg = cooperative_groups;
+#define WARPSIZE 32
+
+template <typename scalar_t>
+__global__ void fp8_adamw_cuda_kernel(
+ scalar_t* __restrict__ params, scalar_t* __restrict__ grads,
+ __nv_fp8_e4m3* __restrict__ exp_avg, float* __restrict__ scale_exp_avg,
+ __nv_fp8_e4m3* __restrict__ exp_avg_sq,
+ float* __restrict__ scale_exp_avg_sq, float beta1, float beta2, float lr,
+ float wd, float eps, int step, int qgroup_size, int total_elements,
+ int total_scale_elements) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
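+ // one quantization scale per block: blockIdx.x indexes the 128-element group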
+ const int scale_idx = blockIdx.x;
+
+ float float_exp_avg, float_exp_avg_sq;
+ float correction1, correction2_sqrt;
+ float denom, update;
+
+ if (idx < total_elements) {
+ // dequantize the optimizer states
+ float_exp_avg = float(exp_avg[idx]) * scale_exp_avg[scale_idx];
+ float_exp_avg_sq = float(exp_avg_sq[idx]) * scale_exp_avg_sq[scale_idx];
+
+ // calculation of optimizer.step()
+ float_exp_avg = beta1 * float_exp_avg + (1 - beta1) * grads[idx];
+ float_exp_avg_sq =
+ beta2 * float_exp_avg_sq + (1 - beta2) * grads[idx] * grads[idx];
+
+ correction1 = 1.0f - powf(beta1, step);
+ correction2_sqrt = sqrtf(1.0f - powf(beta2, step));
+
+ denom = (sqrtf(float_exp_avg_sq) / correction2_sqrt + eps) * correction1;
+ update = (float_exp_avg / denom) + (wd * params[idx]);
+
+ params[idx] = params[idx] - (lr * update);
+ } else {
+ float_exp_avg = 0.0f;
+ float_exp_avg_sq = 0.0f;
+ }
+
+ //// quantize the first-order and second-order momentum
+ int wid = threadIdx.x / WARPSIZE;
+
+ // reduction within a warp
+
+ __shared__ float sharedFirstMaxVal[32];
+ __shared__ float sharedSecondMaxVal[32];
+ cg::thread_block_tile<32> warpTile =
+ cg::tiled_partition<32>(cg::this_thread_block());
+ float firstMaxVal = fabsf(float_exp_avg);
+ float secondMaxVal = fabsf(float_exp_avg_sq);
+
+ for (int i = warpTile.size() / 2; i > 0; i /= 2) {
+ float reduceFirstMaxVal = warpTile.shfl_down(firstMaxVal, i);
+ float reduceSecondMaxVal = warpTile.shfl_down(secondMaxVal, i);
+ firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal));
+ secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal));
+ // printf("First Max: %f\n", reduceFirstMaxVal);
+ }
+ int lane = warpTile.thread_rank();
+ if (lane == 0) sharedFirstMaxVal[wid] = firstMaxVal;
+ if (lane == 0) sharedSecondMaxVal[wid] = secondMaxVal;
+
+ __syncthreads();
+
+ // reduction within a block
+ __shared__ float shared_absmax_exp_avg;
+ __shared__ float shared_absmax_exp_avg_sq;
+ firstMaxVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMaxVal[lane] : 0;
+ secondMaxVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMaxVal[lane] : 0;
+ if (wid == 0) {
+ for (int offset = WARPSIZE / 2; offset > 0; offset /= 2) {
+ float reduceFirstMaxVal =
+ __shfl_down_sync(0xFFFFFFFF, firstMaxVal, offset);
+ float reduceSecondMaxVal =
+ __shfl_down_sync(0xFFFFFFFF, secondMaxVal, offset);
+ firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal));
+ secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal));
+ }
+ if (lane == 0) shared_absmax_exp_avg = firstMaxVal;
+ if (lane == 0) shared_absmax_exp_avg_sq = secondMaxVal;
+ }
+
+ __syncthreads();
+
+ if (idx < total_elements) {
+ // float fp8MaxVal = fp8_dtype_max<__nv_fp8_e4m3>(exp_avg[idx]);
+ float fp8MaxVal = 448;
+
+ shared_absmax_exp_avg = shared_absmax_exp_avg + QUANT_MIN_VAL;
+ shared_absmax_exp_avg_sq = shared_absmax_exp_avg_sq + QUANT_MIN_VAL;
+
+ float new_scale_exp_avg = shared_absmax_exp_avg / fp8MaxVal;
+ float new_scale_exp_avg_sq = shared_absmax_exp_avg_sq / fp8MaxVal;
+
+ // quantize the optimizer states
+ __nv_fp8_e4m3 exp_avg_new =
+ static_cast<__nv_fp8_e4m3>(float_exp_avg / new_scale_exp_avg);
+ __nv_fp8_e4m3 exp_avg_sq_new =
+ static_cast<__nv_fp8_e4m3>(float_exp_avg_sq / new_scale_exp_avg_sq);
+ // __half exp_avg_new = static_cast<__half>(float_exp_avg /
+ // new_scale_exp_avg);
+ // __half exp_avg_sq_new = static_cast<__half>(float_exp_avg_sq /
+ // new_scale_exp_avg_sq);
+
+ // printf("idx: %d, float: %f, quantize: %f\n", idx, float_exp_avg,
+ // (float)exp_avg_new * new_scale_exp_avg);
+
+ // store the output
+ exp_avg[idx] = exp_avg_new;
+ exp_avg_sq[idx] = exp_avg_sq_new;
+ scale_exp_avg[scale_idx] = new_scale_exp_avg;
+ scale_exp_avg_sq[scale_idx] = new_scale_exp_avg_sq;
+ }
+}
+
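+// Host-side launcher: each 128-thread block processes one 128-element
+// quantization group, so the grid size must match the number of per-group
+// scale entries (asserted below).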
+void FP8_AdamW_cuda(torch::Tensor params, // parameter
+ torch::Tensor grads, // gradient
+ torch::Tensor exp_avg, // first order momentum
+ torch::Tensor scale_exp_avg,
+ torch::Tensor exp_avg_sq, // second order momentum
+ torch::Tensor scale_exp_avg_sq, float beta1, float beta2,
+ float lr, float wd, float eps, int step,
+ int qgroup_size) { // other parameters
+
+ // CUDA Blocks
+ int total_elements = params.numel();
+ int total_scale_elements = scale_exp_avg.numel();
+ AT_ASSERTM(qgroup_size == 128,
+ "Only Support 128 per-group quantization currently");
+ const int block_dim = 128; // This should equal to the qgroup_size
+ int grid_dim = (total_elements + qgroup_size - 1) / block_dim;
+ AT_ASSERTM(grid_dim == scale_exp_avg.numel());
+ AT_ASSERTM(grid_dim == scale_exp_avg_sq.numel());
+ const dim3 blocks(grid_dim);
+
+ // Execution
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::kBFloat16, at::kHalf, params.scalar_type(), "fp8_adamw", ([&] {
+ fp8_adamw_cuda_kernel<scalar_t><<<blocks, block_dim>>>(
+ params.data_ptr<scalar_t>(), grads.data_ptr<scalar_t>(),
+ (__nv_fp8_e4m3*)exp_avg.data_ptr(),
+ scale_exp_avg.data_ptr<float>(),
+ (__nv_fp8_e4m3*)exp_avg_sq.data_ptr(),
+ scale_exp_avg_sq.data_ptr<float>(), beta1, beta2, lr, wd, eps, step,
+ qgroup_size, total_elements, total_scale_elements);
+ }));
+}
diff --git a/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda.cpp b/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d50f96dc9d4dc186b660d6b6ff56ff9df1c2df89
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda.cpp
@@ -0,0 +1,34 @@
+#include <torch/extension.h>
+#include <torch/types.h>
+
+void FP8_AdamW_expand_cuda(torch::Tensor params, // parameter
+ torch::Tensor grads, // gradient
+ torch::Tensor exp_avg, // first order momentum
+ torch::Tensor scale_exp_avg,
+ torch::Tensor expand_exp_avg,
+ torch::Tensor sqrtminmax_exp_avg,
+ torch::Tensor exp_avg_sq, // second order momentum
+ torch::Tensor scale_exp_avg_sq,
+ torch::Tensor expand_exp_avg_sq,
+ torch::Tensor sqrtminmax_exp_avg_sq, float beta1,
+ float beta2, float lr, float wd, float eps, int step,
+ int qgroup_size, int expand_min // other parameters
+);
+
+void FP8_AdamW_expand(torch::Tensor params, // parameter
+ torch::Tensor grads, // gradient
+ torch::Tensor exp_avg, // first order momentum
+ torch::Tensor scale_exp_avg, torch::Tensor expand_exp_avg,
+ torch::Tensor sqrtminmax_exp_avg,
+ torch::Tensor exp_avg_sq, // second order momentum
+ torch::Tensor scale_exp_avg_sq,
+ torch::Tensor expand_exp_avg_sq,
+ torch::Tensor sqrtminmax_exp_avg_sq, float beta1,
+ float beta2, float lr, float wd, float eps, int step,
+ int qgroup_size, int expand_min) { // other parameters
+
+ FP8_AdamW_expand_cuda(params, grads, exp_avg, scale_exp_avg, expand_exp_avg,
+ sqrtminmax_exp_avg, exp_avg_sq, scale_exp_avg_sq,
+ expand_exp_avg_sq, sqrtminmax_exp_avg_sq, beta1, beta2,
+ lr, wd, eps, step, qgroup_size, expand_min);
+}
diff --git a/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda_kernel.cu b/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..6130e3716fe5d73bf66cc6f915c3bf99a41f785b
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/fp8_adamw_expand_cuda_kernel.cu
@@ -0,0 +1,247 @@
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/extension.h>
+#include <cmath>
+#include <cooperative_groups.h>
+#include <cuda.h>
+#include <cuda_fp8.h>
+#include <cuda_runtime.h>
+
+#define QUANT_MIN_VAL 1e-20
+
+namespace cg = cooperative_groups;
+#define WARPSIZE 32
+
+template <typename scalar_t>
+__global__ void fp8_adamw_cuda_expand_kernel(
+ scalar_t* __restrict__ params, scalar_t* __restrict__ grads,
+ __nv_fp8_e4m3* __restrict__ exp_avg, float* __restrict__ scale_exp_avg,
+ float* __restrict__ expand_exp_avg, float* __restrict__ sqrtminmax_exp_avg,
+ __nv_fp8_e4m3* __restrict__ exp_avg_sq,
+ float* __restrict__ scale_exp_avg_sq, float* __restrict__ expand_exp_avg_sq,
+ float* __restrict__ sqrtminmax_exp_avg_sq, float beta1, float beta2,
+ float lr, float wd, float eps, int step, int qgroup_size, int expand_min,
+ int total_elements, int total_scale_elements) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int scale_idx = blockIdx.x;
+
+ float float_exp_avg, float_exp_avg_sq;
+ float correction1, correction2_sqrt;
+ float denom, update;
+
+ if (idx < total_elements) {
+ // dequantize the optimizer states
+ float_exp_avg = float(exp_avg[idx]) * scale_exp_avg[scale_idx];
+ int sign_exp_avg = 1 - 2 * signbit(float_exp_avg);
+ float_exp_avg = sign_exp_avg *
+ powf(fabsf(float_exp_avg), 1 / expand_exp_avg[scale_idx]) *
+ sqrtminmax_exp_avg[scale_idx];
+ float_exp_avg_sq = float(exp_avg_sq[idx]) * scale_exp_avg_sq[scale_idx];
+ float_exp_avg_sq =
+ powf(float_exp_avg_sq, 1 / expand_exp_avg_sq[scale_idx]) *
+ sqrtminmax_exp_avg_sq[scale_idx];
+
+ // calculation of optimizer.step()
+ float_exp_avg = beta1 * float_exp_avg + (1 - beta1) * grads[idx];
+ float_exp_avg_sq =
+ beta2 * float_exp_avg_sq + (1 - beta2) * grads[idx] * grads[idx];
+
+ correction1 = 1.0f - powf(beta1, step);
+ correction2_sqrt = sqrtf(1.0f - powf(beta2, step));
+
+ denom = (sqrtf(float_exp_avg_sq) / correction2_sqrt + eps) * correction1;
+ update = (float_exp_avg / denom) + (wd * params[idx]);
+ params[idx] = params[idx] - (lr * update);
+ } else {
+ float_exp_avg = 0.0f;
+ float_exp_avg_sq = 0.0f;
+ }
+
+ //// quantize the first-order and second-order momentum
+ int wid = threadIdx.x / WARPSIZE;
+
+ // reduction within a warp
+ __shared__ float sharedFirstMaxVal[32];
+ __shared__ float sharedFirstMinVal[32];
+ __shared__ float sharedSecondMaxVal[32];
+ __shared__ float sharedSecondMinVal[32];
+ cg::thread_block_tile<32> warpTile =
+ cg::tiled_partition<32>(cg::this_thread_block());
+ float firstMaxVal = fabsf(float_exp_avg);
+ float firstMinVal = fabsf(float_exp_avg);
+ float secondMaxVal = fabsf(float_exp_avg_sq);
+ float secondMinVal = fabsf(float_exp_avg_sq);
+ // Special handling: out-of-range threads must not affect the min reduction
+ if (idx >= total_elements) {
+ firstMinVal = __int_as_float(0x7f7fffff);
+ secondMinVal = __int_as_float(0x7f7fffff);
+ }
+
+ for (int i = warpTile.size() / 2; i > 0; i /= 2) {
+ float reduceFirstMaxVal = warpTile.shfl_down(firstMaxVal, i);
+ float reduceFirstMinVal = warpTile.shfl_down(firstMinVal, i);
+ float reduceSecondMaxVal = warpTile.shfl_down(secondMaxVal, i);
+ float reduceSecondMinVal = warpTile.shfl_down(secondMinVal, i);
+ firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal));
+ firstMinVal = fmin(firstMinVal, fabsf(reduceFirstMinVal));
+ secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal));
+ secondMinVal = fmin(secondMinVal, fabsf(reduceSecondMinVal));
+ // printf("First Max: %f\n", reduceFirstMaxVal);
+ }
+ int lane = warpTile.thread_rank();
+ if (lane == 0) {
+ sharedFirstMaxVal[wid] = firstMaxVal;
+ sharedFirstMinVal[wid] = firstMinVal;
+ sharedSecondMaxVal[wid] = secondMaxVal;
+ sharedSecondMinVal[wid] = secondMinVal;
+ }
+
+ __syncthreads();
+
+ // reduction within a block
+ __shared__ float shared_absmax_exp_avg;
+ __shared__ float shared_absmin_exp_avg;
+ __shared__ float shared_absmax_exp_avg_sq;
+ __shared__ float shared_absmin_exp_avg_sq;
+ firstMaxVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMaxVal[lane] : 0;
+ firstMinVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedFirstMinVal[lane] : 1e9;
+ secondMaxVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMaxVal[lane] : 0;
+ secondMinVal =
+ (threadIdx.x < blockDim.x / warpSize) ? sharedSecondMinVal[lane] : 1e9;
+ if (wid == 0) {
+ for (int offset = WARPSIZE / 2; offset > 0; offset /= 2) {
+ float reduceFirstMaxVal =
+ __shfl_down_sync(0xFFFFFFFF, firstMaxVal, offset);
+ float reduceFirstMinVal =
+ __shfl_down_sync(0xFFFFFFFF, firstMinVal, offset);
+ float reduceSecondMaxVal =
+ __shfl_down_sync(0xFFFFFFFF, secondMaxVal, offset);
+ float reduceSecondMinVal =
+ __shfl_down_sync(0xFFFFFFFF, secondMinVal, offset);
+ firstMaxVal = fmax(firstMaxVal, fabsf(reduceFirstMaxVal));
+ firstMinVal = fmin(firstMinVal, fabsf(reduceFirstMinVal));
+ secondMaxVal = fmax(secondMaxVal, fabsf(reduceSecondMaxVal));
+ secondMinVal = fmin(secondMinVal, fabsf(reduceSecondMinVal));
+ }
+ if (lane == 0) {
+ shared_absmax_exp_avg = firstMaxVal;
+ shared_absmin_exp_avg = firstMinVal;
+ shared_absmax_exp_avg_sq = secondMaxVal;
+ shared_absmin_exp_avg_sq = secondMinVal;
+ }
+ }
+
+ __syncthreads();
+
+ if (idx < total_elements) {
+ // float fp8MaxVal = fp8_dtype_max<__nv_fp8_e4m3>(exp_avg[idx]);
+ // scaling factor before expanding
+ float fp8MaxVal = 448;
+
+ // dynamic exponent quantization part
+ firstMaxVal = shared_absmax_exp_avg + QUANT_MIN_VAL;
+ firstMinVal = shared_absmin_exp_avg + QUANT_MIN_VAL;
+ secondMaxVal = shared_absmax_exp_avg_sq + QUANT_MIN_VAL;
+ secondMinVal = shared_absmin_exp_avg_sq + QUANT_MIN_VAL;
+
+ // calculate the ratio and make the scale to center
+ float firstRatio = firstMaxVal / firstMinVal;
+ float secondRatio = secondMaxVal / secondMinVal;
+ float firstSqrtMinMax = sqrt(firstMaxVal * firstMinVal);
+ float secondSqrtMinMax = sqrt(secondMaxVal * secondMinVal);
+
+ // printf("Max %f, Min %f, Origin %f \n", firstMaxVal, firstMinVal,
+ // float_exp_avg);
+
+ // since we use x^k expander, calculate the optimal expanding factor
+ float ratioUpperBound = fp8MaxVal * fp8MaxVal / 2;
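+ // The exponent is chosen so the group's expanded max/min ratio stays below
+ // 448^2 / 2, i.e. within the dynamic range that E4M3 can represent after the
+ // per-group rescaling below, which keeps small entries from collapsing to zero.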
+ float firstExp =
+ floor((log2f(ratioUpperBound) / log2f(firstRatio)) * expand_min) /
+ expand_min; // if expand_min is 8, for example, then firstExp is a
+ // multiple of 1/8
+ float secondExp =
+ floor((log2f(ratioUpperBound) / log2f(secondRatio)) * expand_min) /
+ expand_min;
+
+ int sign_exp_avg = 1 - 2 * signbit(float_exp_avg);
+ float_exp_avg =
+ sign_exp_avg * powf(fabsf(float_exp_avg) / firstSqrtMinMax, firstExp);
+ float_exp_avg_sq = powf(float_exp_avg_sq / secondSqrtMinMax, secondExp);
+
+ // correspondingly, change the scaling factor
+ float new_scale_exp_avg =
+ powf(firstMaxVal / firstSqrtMinMax, firstExp) / fp8MaxVal;
+ float new_scale_exp_avg_sq =
+ powf(secondMaxVal / secondSqrtMinMax, secondExp) / fp8MaxVal;
+
+ // quantize the optimizer states
+ __nv_fp8_e4m3 exp_avg_new =
+ static_cast<__nv_fp8_e4m3>(float_exp_avg / new_scale_exp_avg);
+ __nv_fp8_e4m3 exp_avg_sq_new =
+ static_cast<__nv_fp8_e4m3>(float_exp_avg_sq / new_scale_exp_avg_sq);
+ // __half exp_avg_new = static_cast<__half>(float_exp_avg /
+ // new_scale_exp_avg);
+ // __half exp_avg_sq_new = static_cast<__half>(float_exp_avg_sq /
+ // new_scale_exp_avg_sq);
+
+ // printf("idx: %d, float: %f, quantize: %f\n", idx, float_exp_avg,
+ // (float)exp_avg_new * new_scale_exp_avg);
+
+ // store the output
+ exp_avg[idx] = exp_avg_new;
+ exp_avg_sq[idx] = exp_avg_sq_new;
+ scale_exp_avg[scale_idx] = new_scale_exp_avg;
+ scale_exp_avg_sq[scale_idx] = new_scale_exp_avg_sq;
+ expand_exp_avg[scale_idx] = firstExp;
+ expand_exp_avg_sq[scale_idx] = secondExp;
+ sqrtminmax_exp_avg[scale_idx] = firstSqrtMinMax;
+ sqrtminmax_exp_avg_sq[scale_idx] = secondSqrtMinMax;
+ }
+}
+
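+// Host-side launcher for the expansion variant: same grid layout as
+// FP8_AdamW_cuda (one 128-thread block per 128-element quantization group).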
+void FP8_AdamW_expand_cuda(torch::Tensor params, // parameter
+ torch::Tensor grads, // gradient
+ torch::Tensor exp_avg, // first order momentum
+ torch::Tensor scale_exp_avg,
+ torch::Tensor expand_exp_avg,
+ torch::Tensor sqrtminmax_exp_avg,
+ torch::Tensor exp_avg_sq, // second order momentum
+ torch::Tensor scale_exp_avg_sq,
+ torch::Tensor expand_exp_avg_sq,
+ torch::Tensor sqrtminmax_exp_avg_sq, float beta1,
+ float beta2, float lr, float wd, float eps, int step,
+ int qgroup_size,
+ int expand_min) { // other parameters
+
+ // CUDA Blocks
+ int total_elements = params.numel();
+ int total_scale_elements = scale_exp_avg.numel();
+ AT_ASSERTM(qgroup_size == 128,
+ "Only Support 128 per-group quantization currently");
+ const int block_dim = 128; // This should equal to the qgroup_size
+ int grid_dim = (total_elements + qgroup_size - 1) / block_dim;
+ AT_ASSERTM(grid_dim == scale_exp_avg.numel());
+ AT_ASSERTM(grid_dim == scale_exp_avg_sq.numel());
+ const dim3 blocks(grid_dim);
+
+ // Execution
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::kBFloat16, at::kHalf, params.scalar_type(), "fp8_adamw", ([&] {
+ fp8_adamw_cuda_expand_kernel<scalar_t><<<blocks, block_dim>>>(
+ params.data_ptr<scalar_t>(), grads.data_ptr<scalar_t>(),
+ (__nv_fp8_e4m3*)exp_avg.data_ptr(),
+ scale_exp_avg.data_ptr<float>(), expand_exp_avg.data_ptr<float>(),
+ sqrtminmax_exp_avg.data_ptr<float>(),
+ (__nv_fp8_e4m3*)exp_avg_sq.data_ptr(),
+ scale_exp_avg_sq.data_ptr<float>(),
+ expand_exp_avg_sq.data_ptr<float>(),
+ sqrtminmax_exp_avg_sq.data_ptr<float>(), beta1, beta2, lr, wd, eps,
+ step, qgroup_size, expand_min, total_elements,
+ total_scale_elements);
+ }));
+}
diff --git a/llava/model/coat/optimizer/kernels/include/fp8_adamw.h b/llava/model/coat/optimizer/kernels/include/fp8_adamw.h
new file mode 100644
index 0000000000000000000000000000000000000000..c04970da0dd564ea27e1e6203a38206d9506cac8
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/include/fp8_adamw.h
@@ -0,0 +1,14 @@
+#ifndef FP8_ADAMW
+#define FP8_ADAMW
+
+#include <torch/extension.h>
+
+void FP8_AdamW(torch::Tensor params, // parameter
+ torch::Tensor grads, // gradient
+ torch::Tensor exp_avg, // first order momentum
+ torch::Tensor scale_exp_avg,
+ torch::Tensor exp_avg_sq, // second order momentum
+ torch::Tensor scale_exp_avg_sq, float beta1, float beta2,
+ float lr, float wd, float eps, int step, int qgroup_size);
+
+#endif // FP8_ADAMW
diff --git a/llava/model/coat/optimizer/kernels/include/fp8_adamw_expand.h b/llava/model/coat/optimizer/kernels/include/fp8_adamw_expand.h
new file mode 100644
index 0000000000000000000000000000000000000000..e1fb69825933154b1a867efaecd54cbdf100201b
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/include/fp8_adamw_expand.h
@@ -0,0 +1,18 @@
+#ifndef FP8_ADAMW_EXPAND
+#define FP8_ADAMW_EXPAND
+
+#include <torch/extension.h>
+
+void FP8_AdamW_expand(torch::Tensor params, // parameter
+ torch::Tensor grads, // gradient
+ torch::Tensor exp_avg, // first order momentum
+ torch::Tensor scale_exp_avg, torch::Tensor expand_exp_avg,
+ torch::Tensor sqrtminmax_exp_avg,
+ torch::Tensor exp_avg_sq, // second order momentum
+ torch::Tensor scale_exp_avg_sq,
+ torch::Tensor expand_exp_avg_sq,
+ torch::Tensor sqrtminmax_exp_avg_sq, float beta1,
+ float beta2, float lr, float wd, float eps, int step,
+ int qgroup_size, int expand_min);
+
+#endif // FP8_ADAMW_EXPAND
diff --git a/llava/model/coat/optimizer/kernels/setup.py b/llava/model/coat/optimizer/kernels/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..8453efeaa472f3b094a2c157acbebf4c571eb09d
--- /dev/null
+++ b/llava/model/coat/optimizer/kernels/setup.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+setup(
+ name="qoptim_cuda",
+ ext_modules=[
+ CUDAExtension(
+ name="qoptim_cuda",
+ sources=[
+ "fp8_adamw_cuda.cpp",
+ "fp8_adamw_cuda_kernel.cu",
+ "fp8_adamw_expand_cuda.cpp",
+ "fp8_adamw_expand_cuda_kernel.cu",
+ "bindings.cpp",
+ ],
+ # include_dirs=[
+ # 'include'
+ # ],
+ extra_compile_args={
+ "nvcc": [
+ "-O3",
+ "-std=c++17",
+ "-gencode=arch=compute_90,code=compute_90",
+ "-DTORCH_USE_CUDA_DSA",
+ "-U__CUDA_NO_HALF_OPERATORS__",
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
+ "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+ "-U__CUDA_NO_HALF2_OPERATORS__",
+ ]
+ },
+ ),
+ ],
+ cmdclass={"build_ext": BuildExtension},
+)
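
The setup script compiles the CUDA sources above into a `qoptim_cuda` extension targeting Hopper (`compute_90`). Below is a purely illustrative sketch of how the compiled module might be called from Python; the Python-visible name `FP8_AdamW` is taken from the header declaration and the FP8 state dtypes are assumptions, since `bindings.cpp` is not shown in this excerpt:

```python
# Hypothetical usage sketch only: the binding name FP8_AdamW mirrors
# include/fp8_adamw.h, and storing the FP8 moments as uint8 tensors is an
# assumption; check bindings.cpp for the actual exposed API.
import torch
import qoptim_cuda  # built from this directory via torch's BuildExtension

n, group = 1024, 128
params = torch.randn(n, device="cuda", dtype=torch.bfloat16)
grads = torch.randn_like(params)
exp_avg = torch.zeros(n, device="cuda", dtype=torch.uint8)      # FP8 E4M3 storage (assumed)
exp_avg_sq = torch.zeros(n, device="cuda", dtype=torch.uint8)   # FP8 E4M3 storage (assumed)
scale_exp_avg = torch.ones(n // group, device="cuda")           # one scale per group
scale_exp_avg_sq = torch.ones(n // group, device="cuda")

qoptim_cuda.FP8_AdamW(
    params, grads, exp_avg, scale_exp_avg, exp_avg_sq, scale_exp_avg_sq,
    0.9, 0.95, 1e-4, 0.0, 1e-8, 1, group,  # beta1, beta2, lr, wd, eps, step, qgroup_size
)
```
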
diff --git a/llava/model/configuration_llava.py b/llava/model/configuration_llava.py
new file mode 100644
index 0000000000000000000000000000000000000000..b413354ab70581faf619a632e1934a1a956250d0
--- /dev/null
+++ b/llava/model/configuration_llava.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Literal, Optional
+
+from pydantic import BaseModel, Field
+from transformers import PretrainedConfig
+
+
+class LlavaConfig(PretrainedConfig):
+ model_type = "llava"
+
+ def __init__(
+ self,
+ llm_cfg=None,
+ vision_tower_cfg=None,
+ speech_tower_cfg=None,
+ sound_tower_cfg=None,
+ mm_projector_cfg=None,
+ speech_mm_projector_cfg=None,
+ sound_mm_projector_cfg=None,
+ architectures=None,
+ resume_path=None,
+ hidden_size=None,
+ mm_hidden_size=None,
+ speech_hidden_size=None,
+ sound_hidden_size=None,
+ image_aspect_ratio=None,
+ num_video_frames=None,
+ fps=None,
+ mm_vision_select_layer=None,
+ mm_vision_select_feature=None,
+ mm_use_im_start_end=False,
+ mm_use_im_patch_token=False,
+ mm_projector_lr=None,
+ speech_mm_projector_lr=None,
+ sound_mm_projector_lr=None,
+ vision_tower_lr=None,
+ speech_tower_lr=None,
+ sound_tower_lr=None,
+ vision_resolution=None,
+ interpolate_mode=None,
+ s2=None,
+ dynamic_s2=None,
+ s2_scales=None,
+ s2_max_split_size=None,
+ s2_resize_output_to_scale_idx=0,
+ min_tiles: Optional[int] = 1,
+ max_tiles: Optional[int] = 12,
+ video_max_tiles: Optional[int] = 1,
+ num_time_tokens=None,
+ time_token_format=None,
+ image_encoder: str = '{"_target_": "llava.model.encoders.BasicImageEncoder"}',
+ video_encoder: str = '{"_target_": "llava.model.encoders.BasicVideoEncoder"}',
+ speech_encoder: str = '{"_target_": "llava.model.encoders.BasicSpeechEncoder"}',
+ sound_encoder: str = '{"_target_": "llava.model.encoders.BasicSoundEncoder"}',
+ **kwargs,
+ ):
+ super().__init__()
+ self.architectures = architectures
+ self.llm_cfg = llm_cfg
+ self.vision_tower_cfg = vision_tower_cfg
+ self.speech_tower_cfg = speech_tower_cfg
+ self.sound_tower_cfg = sound_tower_cfg
+ self.mm_projector_cfg = mm_projector_cfg
+ self.speech_mm_projector_cfg = speech_mm_projector_cfg
+ self.sound_mm_projector_cfg = sound_mm_projector_cfg
+ self.resume_path = resume_path
+
+ self.hidden_size = hidden_size
+ self.mm_hidden_size = mm_hidden_size
+ self.speech_hidden_size = speech_hidden_size
+ self.sound_hidden_size = sound_hidden_size
+ self.image_aspect_ratio = image_aspect_ratio
+ self.num_video_frames = num_video_frames
+ self.fps = fps
+ self.mm_vision_select_layer = mm_vision_select_layer
+ self.mm_vision_select_feature = mm_vision_select_feature
+ self.mm_use_im_start_end = mm_use_im_start_end
+ self.mm_use_im_patch_token = mm_use_im_patch_token
+ self.mm_projector_lr = mm_projector_lr
+ self.speech_mm_projector_lr = speech_mm_projector_lr
+ self.sound_mm_projector_lr = sound_mm_projector_lr
+ self.vision_tower_lr = vision_tower_lr
+ self.speech_tower_lr = speech_tower_lr
+ self.sound_tower_lr = sound_tower_lr
+ self.vision_resolution = vision_resolution
+ self.interpolate_mode = interpolate_mode
+ self.s2 = s2
+ self.dynamic_s2 = dynamic_s2
+ self.s2_scales = s2_scales
+ self.s2_max_split_size = s2_max_split_size
+ self.s2_resize_output_to_scale_idx = s2_resize_output_to_scale_idx
+ self.min_tiles = min_tiles
+ self.max_tiles = max_tiles
+ self.video_max_tiles = video_max_tiles
+ self.num_time_tokens = num_time_tokens
+ self.time_token_format = time_token_format
+
+ self.image_encoder = image_encoder
+ self.video_encoder = video_encoder
+ self.speech_encoder = speech_encoder
+ self.sound_encoder = sound_encoder
+
+
+class JsonSchemaResponseFormat(BaseModel):
+ schema_: str = Field(alias="schema")
+
+
+class ResponseFormat(BaseModel):
+ type: Literal["text", "json_object", "json_schema"]
+ json_schema: Optional[JsonSchemaResponseFormat] = None
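
`LlavaConfig` simply records the sub-configs and multimodal hyperparameters as attributes; nothing is validated at construction time. A minimal construction sketch (all values below are placeholders, not the released AF3 configuration):

```python
# Illustrative only: the field values are placeholders, not the shipped AF3 config.
from llava.model.configuration_llava import LlavaConfig

config = LlavaConfig(
    llm_cfg="Qwen/Qwen2.5-7B",                            # decoder-only LLM backbone
    sound_tower_cfg={"model_type": "af_whisper"},         # placeholder sub-config
    sound_mm_projector_cfg={"mm_projector_type": "mlp"},  # placeholder sub-config
    sound_hidden_size=1280,                               # placeholder encoder width
    hidden_size=3584,                                     # placeholder LLM hidden size
)
# The encoder entry points default to importable targets:
print(config.sound_encoder)  # '{"_target_": "llava.model.encoders.BasicSoundEncoder"}'
```
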
diff --git a/llava/model/deprecate_consolidate.py b/llava/model/deprecate_consolidate.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebb9619df19e4e8791920b4c22d38763538454b1
--- /dev/null
+++ b/llava/model/deprecate_consolidate.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Usage:
+python3 -m llava.model.deprecate_consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
+"""
+import argparse
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llava.model import *
+from llava.model.utils import auto_upgrade
+
+
+def consolidate_ckpt(src_path, dst_path):
+ print("Loading model")
+ auto_upgrade(src_path)
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
+ src_model.save_pretrained(dst_path)
+ src_tokenizer.save_pretrained(dst_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--src", type=str, required=True)
+ parser.add_argument("--dst", type=str, required=True)
+
+ args = parser.parse_args()
+
+ consolidate_ckpt(args.src, args.dst)
diff --git a/llava/model/encoders/__init__.py b/llava/model/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5e80d23d6693e11f0779e342c9b6384d78b4c79
--- /dev/null
+++ b/llava/model/encoders/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .image import *
+from .video import *
+from .speech import *
+from .sound import *
+
diff --git a/llava/model/encoders/__pycache__/__init__.cpython-310.pyc b/llava/model/encoders/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..51821273a8633724bacf886f82c5d3f6dccef0f2
Binary files /dev/null and b/llava/model/encoders/__pycache__/__init__.cpython-310.pyc differ
diff --git a/llava/model/encoders/__pycache__/__init__.cpython-311.pyc b/llava/model/encoders/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..91555cc8635606593bafd5cbfc61fff128718d6e
Binary files /dev/null and b/llava/model/encoders/__pycache__/__init__.cpython-311.pyc differ
diff --git a/llava/model/encoders/__pycache__/base.cpython-310.pyc b/llava/model/encoders/__pycache__/base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d1a5e6a8761723bbebdd2cec99f1d26496e05ce0
Binary files /dev/null and b/llava/model/encoders/__pycache__/base.cpython-310.pyc differ
diff --git a/llava/model/encoders/__pycache__/base.cpython-311.pyc b/llava/model/encoders/__pycache__/base.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e0467e667be5f975f22ccd0d080fe75c8ea560c
Binary files /dev/null and b/llava/model/encoders/__pycache__/base.cpython-311.pyc differ
diff --git a/llava/model/encoders/base.py b/llava/model/encoders/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c9230c954a0f711b063e9b9929cbe0e3d40b55a
--- /dev/null
+++ b/llava/model/encoders/base.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from torch import nn
+
+__all__ = ["BaseEncoder"]
+
+
+class BaseEncoder(nn.Module):
+ def __init__(self, parent: nn.Module) -> None:
+ super().__init__()
+ self._parent = [parent]
+
+ @property
+ def parent(self) -> nn.Module:
+ return self._parent[0]
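
`BaseEncoder` keeps its parent in a one-element list rather than as a plain attribute: assigning an `nn.Module` attribute would register the parent as a submodule of the encoder and pull the parent's parameters into the encoder's own parameter list. A small sketch of the difference:

```python
# Why BaseEncoder stores `parent` inside a list: Python lists are invisible to
# nn.Module attribute registration, so the parent's parameters are not re-exposed.
import torch.nn as nn

class Registered(nn.Module):
    def __init__(self, parent: nn.Module):
        super().__init__()
        self.parent = parent      # registered as a submodule

class Hidden(nn.Module):
    def __init__(self, parent: nn.Module):
        super().__init__()
        self._parent = [parent]   # hidden from registration, as in BaseEncoder

parent = nn.Linear(4, 4)
print(len(list(Registered(parent).parameters())))  # 2 (parent's weight and bias)
print(len(list(Hidden(parent).parameters())))      # 0
```
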
diff --git a/llava/model/encoders/image/__init__.py b/llava/model/encoders/image/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0df4b41c3ec1283ddeeb6024aa584e6cc3a9b4fb
--- /dev/null
+++ b/llava/model/encoders/image/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .basic import *
diff --git a/llava/model/encoders/image/basic.py b/llava/model/encoders/image/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..6428e38f56ffe421edd25bf45f378b25849874f8
--- /dev/null
+++ b/llava/model/encoders/image/basic.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from functools import partial
+from typing import Any, Dict, List, Optional
+
+import torch
+
+from llava.model.encoders.base import BaseEncoder
+
+__all__ = ["BasicImageEncoder"]
+
+
+class BasicImageEncoder(BaseEncoder):
+ def __init__(
+ self,
+ parent: torch.nn.Module,
+ start_tokens: Optional[str] = None,
+ end_tokens: Optional[str] = "\n",
+ ) -> None:
+ super().__init__(parent)
+ self.start_tokens = start_tokens
+ self.end_tokens = end_tokens
+
+ def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
+ if tokens is None:
+ return None
+ token_ids = self.parent.tokenizer(tokens).input_ids
+ token_ids = torch.tensor(token_ids, device=self.parent.device)
+ return self.parent.llm.model.embed_tokens(token_ids)
+
+ def _process_features(
+ self,
+ features: torch.Tensor,
+ start_token_embeds: Optional[torch.Tensor],
+ end_token_embeds: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ if start_token_embeds is not None:
+ features = torch.cat([start_token_embeds, features], dim=0)
+ if end_token_embeds is not None:
+ features = torch.cat([features, end_token_embeds], dim=0)
+ return features
+
+ def forward(self, images: List[torch.Tensor], config: Dict[str, Any], frame_times=None) -> List[torch.Tensor]:
+ images = torch.stack(images, dim=0)
+ features = self.parent.encode_images(images, block_sizes=config.get("block_sizes"))
+
+ process_features = partial(
+ self._process_features,
+ start_token_embeds=self.embed_tokens(self.start_tokens),
+ end_token_embeds=self.embed_tokens(self.end_tokens),
+ )
+ return [process_features(f) for f in features]
diff --git a/llava/model/encoders/sound/__init__.py b/llava/model/encoders/sound/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0df4b41c3ec1283ddeeb6024aa584e6cc3a9b4fb
--- /dev/null
+++ b/llava/model/encoders/sound/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .basic import *
diff --git a/llava/model/encoders/sound/basic.py b/llava/model/encoders/sound/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..beb5ba5af1ba40fc9a9cba5b41e37673a6cdfeb2
--- /dev/null
+++ b/llava/model/encoders/sound/basic.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from functools import partial
+from typing import Any, Dict, List, Optional
+
+import torch
+
+from llava.model.encoders.base import BaseEncoder
+
+__all__ = ["BasicSoundEncoder"]
+
+
+class BasicSoundEncoder(BaseEncoder):
+ def __init__(
+ self,
+ parent: torch.nn.Module,
+ start_tokens: Optional[str] = None,
+ end_tokens: Optional[str] = "\n",
+ ) -> None:
+ super().__init__(parent)
+ self.start_tokens = start_tokens
+ self.end_tokens = end_tokens
+
+ def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
+ if tokens is None:
+ return None
+ token_ids = self.parent.tokenizer(tokens).input_ids
+ token_ids = torch.tensor(token_ids, device=self.parent.device)
+ return self.parent.llm.model.embed_tokens(token_ids)
+
+ def _process_features(
+ self,
+ features: torch.Tensor,
+ start_token_embeds: Optional[torch.Tensor],
+ end_token_embeds: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ features = features.to(self.parent.device)
+ if start_token_embeds is not None:
+ features = torch.cat([start_token_embeds, features], dim=0)
+ if end_token_embeds is not None:
+ features = torch.cat([features, end_token_embeds], dim=0)
+ return features
+
+ def forward(self, sounds: List[torch.Tensor], config: Dict[str, Any], masks: List[torch.Tensor]) -> List[torch.Tensor]:
+ sounds = torch.stack(sounds, dim=0)
+ masks = torch.stack(masks, dim=0)
+ features = self.parent.encode_sound(sounds, masks)
+ process_features = partial(
+ self._process_features,
+ start_token_embeds=self.embed_tokens(self.start_tokens),
+ end_token_embeds=self.embed_tokens(self.end_tokens),
+ )
+ return [process_features(f) for f in features]
diff --git a/llava/model/encoders/speech/__init__.py b/llava/model/encoders/speech/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0df4b41c3ec1283ddeeb6024aa584e6cc3a9b4fb
--- /dev/null
+++ b/llava/model/encoders/speech/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .basic import *
diff --git a/llava/model/encoders/speech/basic.py b/llava/model/encoders/speech/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..68226501c5fcfc8d0cc6030ca71b324a7cb7f360
--- /dev/null
+++ b/llava/model/encoders/speech/basic.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from functools import partial
+from typing import Any, Dict, List, Optional
+
+import torch
+
+from llava.model.encoders.base import BaseEncoder
+
+__all__ = ["BasicSpeechEncoder"]
+
+
+class BasicSpeechEncoder(BaseEncoder):
+ def __init__(
+ self,
+ parent: torch.nn.Module,
+ start_tokens: Optional[str] = None,
+ end_tokens: Optional[str] = "\n",
+ ) -> None:
+ super().__init__(parent)
+ self.start_tokens = start_tokens
+ self.end_tokens = end_tokens
+
+ def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
+ if tokens is None:
+ return None
+ token_ids = self.parent.tokenizer(tokens).input_ids
+ token_ids = torch.tensor(token_ids, device=self.parent.device)
+ return self.parent.llm.model.embed_tokens(token_ids)
+
+ def _process_features(
+ self,
+ features: torch.Tensor,
+ start_token_embeds: Optional[torch.Tensor],
+ end_token_embeds: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ if start_token_embeds is not None:
+ features = torch.cat([start_token_embeds, features], dim=0)
+ if end_token_embeds is not None:
+ features = torch.cat([features, end_token_embeds], dim=0)
+ return features
+
+ def forward(self, speeches: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
+ speeches = torch.stack(speeches, dim=0)
+ features = self.parent.encode_speech(speeches)
+ process_features = partial(
+ self._process_features,
+ start_token_embeds=self.embed_tokens(self.start_tokens),
+ end_token_embeds=self.embed_tokens(self.end_tokens),
+ )
+ return [process_features(f) for f in features]
diff --git a/llava/model/encoders/video/__init__.py b/llava/model/encoders/video/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b00796ffe81fd8d1cb50cd604a353e8c4e8199ad
--- /dev/null
+++ b/llava/model/encoders/video/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from .basic import *
+from .tsp import *
diff --git a/llava/model/encoders/video/basic.py b/llava/model/encoders/video/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b84c6917874686d071ef60a5bd6f2cab0a7f9e4
--- /dev/null
+++ b/llava/model/encoders/video/basic.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from functools import partial
+from typing import Any, Dict, List, Optional
+
+import torch
+
+from llava.model.encoders.base import BaseEncoder
+
+__all__ = ["BasicVideoEncoder"]
+
+
+class BasicVideoEncoder(BaseEncoder):
+ def __init__(
+ self,
+ parent: torch.nn.Module,
+ start_tokens: Optional[str] = None,
+ end_tokens: Optional[str] = "\n",
+ ) -> None:
+ super().__init__(parent)
+ self.start_tokens = start_tokens
+ self.end_tokens = end_tokens
+
+ def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
+ if tokens is None:
+ return None
+ token_ids = self.parent.tokenizer(tokens).input_ids
+ token_ids = torch.tensor(token_ids, device=self.parent.device)
+ return self.parent.llm.model.embed_tokens(token_ids)
+
+ def _process_features(
+ self,
+ features: torch.Tensor,
+ start_token_embeds: Optional[torch.Tensor],
+ end_token_embeds: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ if start_token_embeds is not None:
+ start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0)
+ features = torch.cat([start_embeds, features], dim=1)
+ if end_token_embeds is not None:
+ end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0)
+ features = torch.cat([features, end_embeds], dim=1)
+ return features.flatten(0, 1)
+
+ def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
+ num_frames = [video.shape[0] for video in videos]
+ images = torch.cat(videos, dim=0)
+ features = self.parent.encode_images(images)
+ features = torch.split(features, num_frames)
+ process_features = partial(
+ self._process_features,
+ start_token_embeds=self.embed_tokens(self.start_tokens),
+ end_token_embeds=self.embed_tokens(self.end_tokens),
+ )
+ return [process_features(f) for f in features]
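
`BasicVideoEncoder` encodes all frames in one batch, appends the end-token embedding (and prepends a start-token embedding when configured) to every frame, then flattens the frames into a single token sequence. A shape-only walk-through of `_process_features` with illustrative sizes:

```python
# Shape walk-through of BasicVideoEncoder._process_features (sizes are illustrative).
import torch

num_frames, tokens_per_frame, hidden = 8, 256, 3584   # assumed sizes
features = torch.zeros(num_frames, tokens_per_frame, hidden)
end_token_embeds = torch.zeros(1, hidden)              # embedding of "\n"

end = torch.stack([end_token_embeds] * features.shape[0], dim=0)  # (8, 1, hidden)
features = torch.cat([features, end], dim=1)                      # (8, 257, hidden)
flat = features.flatten(0, 1)                                     # one sequence of 8 * 257 tokens
print(flat.shape)  # torch.Size([2056, 3584])
```
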
diff --git a/llava/model/encoders/video/tsp.py b/llava/model/encoders/video/tsp.py
new file mode 100644
index 0000000000000000000000000000000000000000..efa06530b3aea80a994efb5bd3d5db70f46bd577
--- /dev/null
+++ b/llava/model/encoders/video/tsp.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+from .basic import BasicVideoEncoder
+
+__all__ = ["TSPVideoEncoder"]
+
+
+def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
+ return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1)
+
+
+class TSPVideoEncoder(BasicVideoEncoder):
+ def __init__(
+ self,
+ parent: torch.nn.Module,
+ pool_sizes: List[Tuple[int, int, int]],
+ start_tokens: Optional[str] = None,
+ end_tokens: Optional[str] = "\n",
+ sep_tokens: Optional[str] = None,
+ ) -> None:
+ super().__init__(parent, start_tokens=start_tokens, end_tokens=end_tokens)
+ self.pool_sizes = pool_sizes
+ self.sep_tokens = sep_tokens
+
+ def _process_features(
+ self,
+ inputs: torch.Tensor,
+ start_token_embeds: Optional[torch.Tensor],
+ end_token_embeds: Optional[torch.Tensor],
+ sep_token_embeds: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ nt, ns = inputs.shape[:2]
+ nl = int(ns**0.5)
+ outputs = []
+ for pool_size in self.pool_sizes:
+ features = inputs.view(nt, nl, nl, -1)
+ for dim, p in enumerate(pool_size):
+ features = pool(features, p, dim=dim)
+ features = features.flatten(1, 2)
+ features = super()._process_features(
+ features,
+ start_token_embeds=start_token_embeds,
+ end_token_embeds=end_token_embeds,
+ )
+ if sep_token_embeds is not None:
+ features = torch.cat([features, sep_token_embeds], dim=0)
+ outputs.append(features)
+ return torch.cat(outputs, dim=0)
+
+ def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
+ num_frames = [video.shape[0] for video in videos]
+ images = torch.cat(videos, dim=0)
+ features = self.parent.encode_images(images)
+ features = torch.split(features, num_frames)
+ process_features = partial(
+ self._process_features,
+ start_token_embeds=self.embed_tokens(self.start_tokens),
+ end_token_embeds=self.embed_tokens(self.end_tokens),
+ sep_token_embeds=self.embed_tokens(self.sep_tokens),
+ )
+ return [process_features(f) for f in features]
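
The `pool` helper averages fixed-size windows along a chosen dimension, and `TSPVideoEncoder` applies it over the temporal axis and both spatial axes of the `nt x nl x nl x d` feature grid, once per entry in `pool_sizes`. A quick numeric check of the helper as defined above:

```python
# Quick check of the pool() helper from tsp.py: mean over windows of `size`
# along dimension `dim`.
import torch

def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
    return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1)

x = torch.arange(8.0).view(4, 2)   # e.g. 4 frames x 2 tokens
print(pool(x, 2, dim=0))           # tensor([[1., 2.], [5., 6.]])        -- averages frame pairs
print(pool(x, 2, dim=1))           # tensor([[0.5], [2.5], [4.5], [6.5]]) -- averages token pairs
```
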
diff --git a/llava/model/language_model/builder.py b/llava/model/language_model/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6936800c3c06666fc9d3d02f7424d0eb6409ea39
--- /dev/null
+++ b/llava/model/language_model/builder.py
@@ -0,0 +1,223 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import math
+import os
+import os.path as osp
+import warnings
+from dataclasses import asdict
+from typing import Tuple
+
+import torch
+from huggingface_hub import file_exists, repo_exists
+from huggingface_hub.utils import HFValidationError
+from transformers import (
+ AutoConfig,
+ AutoModelForCausalLM,
+ AutoModelForVision2Seq,
+ AutoTokenizer,
+ PretrainedConfig,
+ PreTrainedModel,
+ PreTrainedTokenizer,
+)
+
+
+from llava.constants import MEDIA_TOKENS
+from llava.model.utils import packing
+from llava.utils.logging import logger
+from llava.utils.tokenizer import infer_stop_tokens
+
+
+def has_tokenizer(repo_id_or_path: str) -> bool:
+ # Check if the tokenizer is in a local directory
+ if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
+ return True
+
+ # Check if the tokenizer is in a Hugging Face Hub repo
+ try:
+ return repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, "tokenizer_config.json")
+ except HFValidationError:
+ return False
+
+
+def context_length_extension(config):
+ orig_ctx_len = getattr(config, "max_position_embeddings", None)
+ model_max_length = getattr(config, "model_max_length", None)
+ if orig_ctx_len and model_max_length > orig_ctx_len:
+ print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
+ scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
+ config.rope_scaling = {"type": "linear", "factor": scaling_factor}
+ return config
+
+
+def build_llm_and_tokenizer(
+ model_name_or_path: str,
+ config: PretrainedConfig,
+ attn_implementation=None,
+ model_max_length=None,
+ *args,
+ **kwargs,
+) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+ # print(model_name_or_path)
+ llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
+ llm_cfg._attn_implementation = attn_implementation
+ llm_cfg.model_max_length = model_max_length
+ if model_max_length is not None:
+ context_length_extension(llm_cfg)
+
+ # Quantization related
+ quantization_restore_from_checkpoint = False
+ if kwargs.get("quantize_model_class") is not None:
+ assert kwargs.get("model_args") is not None
+ quantize_model_class = kwargs.pop("quantize_model_class", None)
+ model_args = kwargs.pop("model_args", None)
+
+ if quantize_model_class == "QLlamaForCausalLM": # TODO: Also change the name of this class
+ from .qllama import QLlamaConfig
+
+ llm_cfg.architectures = "QLlamaForCausalLM"
+ _attn_implementation = llm_cfg._attn_implementation
+ llm_cfg = QLlamaConfig(**llm_cfg.to_dict())
+ llm_cfg._attn_implementation = _attn_implementation
+ elif quantize_model_class == "QMemLlamaForCausalLM": # TODO: Also change the name of this class
+ from .qmemllama import QMemLlamaConfig
+
+ llm_cfg.architectures = "QMemLlamaForCausalLM"
+ llm_cfg = QMemLlamaConfig(**llm_cfg.to_dict())
+ elif quantize_model_class == "FP8LinearQwen2ForCausalLM":
+ from .configuration_quantize import QuantizationConfig
+ from .fp8linearqwen2 import FP8LinearQwen2Config
+
+ llm_cfg.architectures = "FP8LinearQwen2ForCausalLM"
+ coat_fp8_args = QuantizationConfig(**asdict(model_args))
+
+ # Remove the quantization args from llm_cfg and make it an independent config
+ model_args_dict = asdict(model_args)
+ for key in asdict(coat_fp8_args).keys():
+ model_args_dict.pop(key, None)
+
+ llm_cfg.coat_fp8_args = asdict(coat_fp8_args)
+ _attn_implementation = llm_cfg._attn_implementation
+
+ llm_cfg = FP8LinearQwen2Config(**llm_cfg.to_dict())
+ llm_cfg._attn_implementation = _attn_implementation
+
+ elif quantize_model_class == "FP8ActivationQwen2ForCausalLM":
+ from ..coat.activation.models._fp8_quantization_config import QuantizationConfig
+ from .fp8activationqwen2 import FP8ActivationQwen2Config
+
+ quantization_restore_from_checkpoint = True
+
+ llm_cfg.architectures = "FP8ActivationQwen2ForCausalLM"
+ coat_fp8_args = QuantizationConfig(**asdict(model_args))
+
+ # Remove the quantization args from llm_cfg and make it an independent config
+ model_args_dict = asdict(model_args)
+ for key in asdict(coat_fp8_args).keys():
+ model_args_dict.pop(key, None)
+
+ llm_cfg.coat_fp8_args = asdict(coat_fp8_args)
+ _attn_implementation = llm_cfg._attn_implementation
+
+ llm_cfg = FP8ActivationQwen2Config(**llm_cfg.to_dict())
+ llm_cfg._attn_implementation = _attn_implementation
+
+ elif quantize_model_class == "FP8ActivationResidualQwen2ForCausalLM":
+ from ..coat.activation.models._fp8_quantization_config import QuantizationConfig
+ from .fp8activationresidualqwen2 import FP8ActivationResidualQwen2Config
+
+ quantization_restore_from_checkpoint = True
+
+ llm_cfg.architectures = "FP8ActivationResidualQwen2ForCausalLM"
+ coat_fp8_args = QuantizationConfig(**asdict(model_args))
+
+ # Remove the quantization args from llm_cfg and make it an independent config
+ model_args_dict = asdict(model_args)
+ for key in asdict(coat_fp8_args).keys():
+ model_args_dict.pop(key, None)
+
+ llm_cfg.coat_fp8_args = asdict(coat_fp8_args)
+ _attn_implementation = llm_cfg._attn_implementation
+
+ llm_cfg = FP8ActivationResidualQwen2Config(**llm_cfg.to_dict())
+ llm_cfg._attn_implementation = _attn_implementation
+ else:
+ raise ValueError(f"{quantize_model_class} is not supported quantize_model_class.")
+
+ kwargs.pop("quantize_model_class", None)
+
+ if quantize_model_class in [
+ "FP8LinearQwen2ForCausalLM",
+ "FP8ActivationQwen2ForCausalLM",
+ "FP8ActivationResidualQwen2ForCausalLM",
+ ]: # Remove the quantization args from llm_cfg and make it an independent config
+ llm_cfg.update(model_args_dict)
+ else:
+ llm_cfg.update(asdict(model_args))
+ # print(model_args)
+
+ if quantization_restore_from_checkpoint:
+ fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)
+
+ llm = AutoModelForCausalLM.from_pretrained(
+ model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
+ )
+
+ else:
+ llm = AutoModelForCausalLM.from_pretrained(
+ model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
+ )
+ packing.patch(llm)
+
+ # Locate the tokenizer.
+ llm_path = model_name_or_path
+ if not has_tokenizer(llm_path):
+ llm_path = osp.join(llm_path, "llm")
+ if not has_tokenizer(llm_path):
+ raise ValueError(f"Cannot find tokenizer in {llm_path}.")
+
+ tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False)
+ if model_max_length is not None:
+ tokenizer.model_max_length = model_max_length
+
+ # Load chat template if specified.
+ if getattr(config, "chat_template", None) is not None:
+ logger.info(f"Using chat template: {config.chat_template}")
+ fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
+ with open(fpath) as fd:
+ chat_template = fd.read()
+ tokenizer.chat_template = chat_template.replace(" ", "").replace("\n", "")
+
+ # Set stop tokens for the tokenizer
+ tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
+ tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)
+
+ # Add media tokens to the tokenizer
+ tokenizer.media_tokens = MEDIA_TOKENS
+ tokenizer.media_token_ids = {}
+ for name, token in MEDIA_TOKENS.items():
+ tokenizer.add_tokens([token], special_tokens=True)
+ tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)
+
+ # TODO(ligeng): is this necessary for llava?
+ config.hidden_size = llm.config.hidden_size
+ return llm, tokenizer
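
`context_length_extension` stretches the backbone's context window with linear RoPE scaling; the factor is the ceiling of the requested `model_max_length` over the original `max_position_embeddings`. A worked sketch of the rule:

```python
# Sketch of the linear RoPE scaling rule applied in context_length_extension above.
import math

def rope_scaling(orig_ctx_len: int, model_max_length: int):
    """Return the rope_scaling dict the builder would set, or None if not needed."""
    if orig_ctx_len and model_max_length > orig_ctx_len:
        factor = float(math.ceil(model_max_length / orig_ctx_len))
        return {"type": "linear", "factor": factor}
    return None

print(rope_scaling(32768, 40960))  # {'type': 'linear', 'factor': 2.0}
print(rope_scaling(32768, 32768))  # None -- no scaling when the window already fits
```
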
diff --git a/llava/model/language_model/chat_templates/mistral.jinja b/llava/model/language_model/chat_templates/mistral.jinja
new file mode 100644
index 0000000000000000000000000000000000000000..774073e68db8dc9ffc65daab8d948c0079235a5d
--- /dev/null
+++ b/llava/model/language_model/chat_templates/mistral.jinja
@@ -0,0 +1,11 @@
+{{ bos_token }}
+
+{% for message in messages if message['content'] is not none %}
+ {% if message['role'] == 'system' %}
+ {{ message['content'] | trim + '\n\n' }}
+ {% elif message['role'] == 'user' %}
+ {{ '[INST] ' + message['content'] | trim + ' [/INST]' }}
+ {% elif message['role'] == 'assistant' %}
+ {{ ' ' + message['content'] | trim + eos_token }}
+ {% endif %}
+{% endfor %}
diff --git a/llava/model/language_model/chat_templates/qwen2.jinja b/llava/model/language_model/chat_templates/qwen2.jinja
new file mode 100644
index 0000000000000000000000000000000000000000..d3b0fd8c09e426c3a8c6e29edde3e56586c961bb
--- /dev/null
+++ b/llava/model/language_model/chat_templates/qwen2.jinja
@@ -0,0 +1,11 @@
+{% if messages[0]['role'] != 'system' %}
+ {{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}
+{% endif %}
+
+{% for message in messages if message['content'] is not none %}
+ {{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}
+{% endfor %}
+
+{% if add_generation_prompt %}
+ {{ '<|im_start|>assistant\n' }}
+{% endif %}
diff --git a/llava/model/language_model/configuration_quantize.py b/llava/model/language_model/configuration_quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b226923e1ab06e6f868455e72dd8430386c8ac8
--- /dev/null
+++ b/llava/model/language_model/configuration_quantize.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+from dataclasses import dataclass
+
+from transformers import PretrainedConfig
+
+
+@dataclass
+class QuantizationConfig:
+ quantize_model: str = "false"
+ symm: bool = True
+ epsilon: float = 1e-10
+ fabit: str = "E4M3"
+ fwbit: str = "E4M3"
+ bobit: str = "E5M2"
+ row_blocksize: int = -1
+ col_blocksize: int = -1
+ qchoice: str = "none"
+ pad_to_multiple_of: int = 0
+
+ def __init__(
+ self,
+ quantize_model,
+ symm,
+ epsilon,
+ fabit,
+ fwbit,
+ bobit,
+ row_blocksize,
+ col_blocksize,
+ qchoice,
+ pad_to_multiple_of,
+ **kwargs,
+ ):
+ super().__init__()
+ self.quantize_model = quantize_model
+ self.symm = symm
+ self.epsilon = epsilon
+ self.fabit = fabit
+ self.fwbit = fwbit
+ self.bobit = bobit
+ self.row_blocksize = row_blocksize
+ self.col_blocksize = col_blocksize
+ self.qchoice = qchoice
+ self.pad_to_multiple_of = pad_to_multiple_of
+
+
+# class QuantizationConfig(PretrainedConfig):
+# def __init__(
+# self,
+# quantize_model="false",
+# symm=True,
+# epsilon=1e-10,
+# fabit="E4M3",
+# fwbit="E4M3",
+# bobit="E5M2",
+# row_blocksize=-1,
+# col_blocksize=-1,
+# qchoice="none",
+# pad_to_multiple_of=0,
+# **kwargs,
+# ):
+# super().__init__()
+# self.quantize_model = quantize_model
+# self.symm = symm
+# self.epsilon = epsilon
+# self.fabit = fabit
+# self.fwbit = fwbit
+# self.bobit = bobit
+# self.row_blocksize = row_blocksize
+# self.col_blocksize = col_blocksize
+# self.qchoice = qchoice
+# self.pad_to_multiple_of = pad_to_multiple_of
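
Because `QuantizationConfig` defines its own `__init__` without default arguments, every field must be passed explicitly when the dataclass is constructed directly (elsewhere in the repo these values appear to be filled in from parsed training arguments). A construction sketch using the documented class-level defaults:

```python
# Constructing the QuantizationConfig dataclass defined above; the values shown
# are its documented class-level defaults.
from llava.model.language_model.configuration_quantize import QuantizationConfig

qargs = QuantizationConfig(
    quantize_model="false",
    symm=True,
    epsilon=1e-10,
    fabit="E4M3",          # forward-activation format
    fwbit="E4M3",          # forward-weight format
    bobit="E5M2",          # backward-output format
    row_blocksize=-1,
    col_blocksize=-1,
    qchoice="none",
    pad_to_multiple_of=0,
)
print(qargs.fabit, qargs.bobit)  # E4M3 E5M2
```
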
diff --git a/llava/model/language_model/fp8_qwen2_convert_from_hf.py b/llava/model/language_model/fp8_qwen2_convert_from_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..71a7f9f602491851c21cf7a7702a7174b414bfee
--- /dev/null
+++ b/llava/model/language_model/fp8_qwen2_convert_from_hf.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import argparse
+import os
+from dataclasses import asdict, dataclass, field
+from typing import Optional
+
+import torch
+import transformers
+from transformers import AutoConfig, AutoModelForCausalLM
+
+from ..coat.activation.models._fp8_quantization_config import QuantizationConfig
+from .fp8activationqwen2 import FP8ActivationQwen2Config, make_state_dict_compatible
+
+
+@dataclass
+class ConvertArguments:
+ model_name: str = field(metadata={"help": "The model name or path of the Qwen2 model to download"})
+ save_path: str = field(metadata={"help": "The path where the converted model weights will be saved"})
+ cache_dir: str = field(default=None, metadata={"help": "Directory to cache the model"})
+
+
+def download_and_convert_qwen2(convert_args: ConvertArguments, quantization_args: QuantizationConfig):
+ """
+ Downloads a Qwen2 model, converts its weights using `make_state_dict_compatible`,
+ and saves the converted model.
+
+ Args:
+ model_name (str): The model name or path of the Qwen2 model to download.
+ save_path (str): The path where the converted model weights will be saved.
+ cache_dir (Optional[str]): Directory to cache the model. Defaults to None.
+
+ Returns:
+ None
+ """
+ model_name = convert_args.model_name
+ save_path = convert_args.save_path
+ cache_dir = convert_args.cache_dir
+
+ # Step 1: Download the original Qwen2 model
+ model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
+
+ # Step 2: Initialize the model configuration for FP8 or other custom config
+ config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
+
+ # Step 3: Apply make_state_dict_compatible to convert weights
+ compatible_state_dict = make_state_dict_compatible(model.state_dict())
+
+ # Step 4: Create a new model instance with compatible configuration
+ fp8_config = FP8ActivationQwen2Config(**config.to_dict())
+ fp8_config.coat_fp8_args = asdict(quantization_args)
+ fp8_config._name_or_path = save_path
+
+ converted_model = AutoModelForCausalLM.from_config(fp8_config, torch_dtype=torch.bfloat16)
+ converted_model.load_state_dict(compatible_state_dict)
+
+ # Step 5: Save the converted model and configuration using save_pretrained
+ os.makedirs(save_path, exist_ok=True)
+ converted_model.save_pretrained(save_path)
+ print(f"Converted model saved at {save_path}")
+
+
+if __name__ == "__main__":
+ # Parse command-line arguments
+ parser = transformers.HfArgumentParser((ConvertArguments, QuantizationConfig)) # NOTE: FP8
+ convert_args, quantization_args = parser.parse_args_into_dataclasses()
+
+ # Call the function with parsed arguments
+ download_and_convert_qwen2(convert_args, quantization_args)
diff --git a/llava/model/language_model/fp8activationqwen2.py b/llava/model/language_model/fp8activationqwen2.py
new file mode 100644
index 0000000000000000000000000000000000000000..34bde452da53042dc55b39ae204e541e1eb1f920
--- /dev/null
+++ b/llava/model/language_model/fp8activationqwen2.py
@@ -0,0 +1,1722 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen2 model."""
+
+import math
+import os
+from dataclasses import asdict, dataclass, field
+from fnmatch import fnmatch
+from functools import lru_cache
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+from transformers.models.qwen2.modeling_qwen2 import (
+ Qwen2Attention,
+ Qwen2DecoderLayer,
+ Qwen2FlashAttention2,
+ Qwen2ForCausalLM,
+ Qwen2MLP,
+ Qwen2Model,
+ Qwen2PreTrainedModel,
+ Qwen2RMSNorm,
+ Qwen2RotaryEmbedding,
+ Qwen2SdpaAttention,
+ apply_rotary_pos_emb,
+ repeat_kv,
+ rotate_half,
+)
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+
+# FP8 related
+from ..coat.activation.models._fp8_quantization_config import QuantizationConfig
+from ..coat.activation.models._fp8_weightcache import FP8CacheWeightModule
+from ..coat.activation.models._fp8manager import FP8Manager
+from ..coat.activation.real_quantization import (
+ Coat_quantize_bgn,
+ Coat_quantize_end,
+ fp8_add_Ifp_Ifp_Ofp_Og16,
+ fp8_add_Ifp_Ifp_Ofp_Opt,
+ fp8_division,
+ fp8_division_transpose,
+ fp8_gelu_backward,
+ fp8_gelu_forward,
+ fp8_layernorm_noparam_backward,
+ fp8_layernorm_noparam_forward,
+ fp8_linear_backward,
+ fp8_linear_forward,
+ fp8_mul_backward,
+ fp8_mul_forward,
+ fp8_quantize,
+ fp8_quantize_pertensor,
+ fp8_quantize_pertensor_transpose,
+ fp8_rmsnorm_backward,
+ fp8_rmsnorm_forward,
+ fp8_silu_backward,
+ fp8_silu_forward,
+ fp8_transpose,
+)
+from ..liger.cross_entropy import LigerForCausalLMLoss
+from ..qlinear_te import QLinearTE
+
+# from .configuration_quantize import QuantizationConfig
+
+if is_flash_attn_2_available():
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
+
+logger = logging.get_logger(__name__)
+
+
+class FP8ActivationQwen2Config(Qwen2Config):
+ model_type = "fp8activation_qwen2"
+
+ def __init__(
+ self,
+ coat_fp8_args=None,
+ vocab_size=151936,
+ hidden_size=4096,
+ intermediate_size=22016,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ hidden_act="silu",
+ max_position_embeddings=32768,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ use_sliding_window=False,
+ sliding_window=4096,
+ max_window_layers=28,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_size,
+ hidden_size,
+ intermediate_size,
+ num_hidden_layers,
+ num_attention_heads,
+ num_key_value_heads,
+ hidden_act,
+ max_position_embeddings,
+ initializer_range,
+ rms_norm_eps,
+ use_cache,
+ tie_word_embeddings,
+ rope_theta,
+ rope_scaling,
+ use_sliding_window,
+ sliding_window,
+ max_window_layers,
+ attention_dropout,
+ **kwargs,
+ )
+
+ self.coat_fp8_args = coat_fp8_args
+
+
+class FP8ActivationQwen2BeforeAttentionResidual(FP8CacheWeightModule):
+ """
+ This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 1 * Linear layers
+ """
+
+ def __init__(self, config: FP8ActivationQwen2Config, qargs: QuantizationConfig, layer_idx: Optional[int] = None):
+ super().__init__(config, qargs, layer_idx)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+
+ def forward(self, re_x, x, s, rmsnorm_weight):
+ if self.training:
+ if self.qargs.weight_memory_efficient:
+ # Prepare
+ with torch.no_grad():
+ if FP8Manager.is_first_microbatch:
+ # Directly use the corresponding weight
+ weight1, _, weight1_s = self.prepare_weight(
+ self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch
+ )
+ weight2, _, weight2_s = self.prepare_weight(
+ self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch
+ )
+ weight3, _, weight3_s = self.prepare_weight(
+ self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch
+ )
+ else:
+ weight1_s = self.prepare_weight(self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch)
+ weight2_s = self.prepare_weight(self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch)
+ weight3_s = self.prepare_weight(self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch)
+
+ weight1, weight2, weight3 = None, None, None
+ return _FP8ActivationQwen2BeforeAttentionResidual.apply(
+ re_x,
+ x,
+ s,
+ self.q_proj.weight,
+ weight1,
+ None,
+ weight1_s,
+ self.q_proj.bias,
+ self.k_proj.weight,
+ weight2,
+ None,
+ weight2_s,
+ self.k_proj.bias,
+ self.v_proj.weight,
+ weight3,
+ None,
+ weight3_s,
+ self.v_proj.bias,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # Prepare
+ with torch.no_grad():
+ weight1, weight1_t, weight1_s = self.prepare_weight(
+ self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch
+ )
+ weight2, weight2_t, weight2_s = self.prepare_weight(
+ self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch
+ )
+ weight3, weight3_t, weight3_s = self.prepare_weight(
+ self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch
+ )
+ return _FP8ActivationQwen2BeforeAttentionResidual.apply(
+ re_x,
+ x,
+ s,
+ self.q_proj.weight,
+ weight1,
+ weight1_t,
+ weight1_s,
+ self.k_proj.weight,
+ weight2,
+ weight2_t,
+ weight2_s,
+ self.v_proj.weight,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ raise NotImplementedError("This should be implemented in the future")
+ return re_x, self.att_proj(self.attn_norm(re_x))
+
+
+class _FP8ActivationQwen2BeforeAttentionResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ re_x,
+ in_x,
+ in_s,
+ weight1_origin,
+ weight1,
+ weight1_t,
+ weight1_s,
+ weight1_bias,
+ weight2_origin,
+ weight2,
+ weight2_t,
+ weight2_s,
+ weight2_bias,
+ weight3_origin,
+ weight3,
+ weight3_t,
+ weight3_s,
+ weight3_bias,
+ rmsnorm_weight,
+ group_size,
+ fwobits,
+ layer_id,
+ config,
+ qargs,
+ eps=1e-5,
+ ):
+ # for autograd
+ if fwobits["fabit"] == "E4M3":
+ # in_x = in_x.to(torch.float8_e4m3fn)
+ in_x = in_x.view(torch.float8_e4m3fn)
+ else:
+ raise ValueError("fabit should be E4M3")
+
+ # LayerNorm
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward(
+ in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True
+ )
+
+ # Linear Layer QKV Projection
+ if qargs.weight_memory_efficient:
+ assert (
+ weight1_t is None and weight2_t is None and weight3_t is None
+ ) # we should not pass W^T to here if weight memory efficient
+ if weight1 is None:
+ assert weight2 is None and weight3 is None
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
+
+ fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size, bias=weight1_bias) # query states
+ fc2_x = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, False, group_size, bias=weight2_bias) # key states
+ fc3_x = fp8_linear_forward(ln_x, ln_s, weight3, weight3_s, False, group_size, bias=weight3_bias) # value states
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s)
+ if qargs.weight_memory_efficient:
+ assert weight1_t is None and weight2_t is None and weight3_t is None
+ ctx.weight = weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s
+ else:
+ ctx.weight = weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s
+ ctx.bias = weight1_bias, weight2_bias, weight3_bias
+
+ ctx.group_size = group_size
+ ctx.ln_utils = ln_utils
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ return re_x, fc1_x, fc2_x, fc3_x
+
+ @staticmethod
+ def backward(ctx, fp_grad, query_g, key_g, value_g):
+ in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors
+ weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s = ctx.weight
+ weight1_bias, weight2_bias, weight3_bias = ctx.bias
+
+ group_size = ctx.group_size
+ rms_weight, rstd, num_warps = ctx.ln_utils
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # ==================== Begin backward ====================
+ # Gradient of Bias TODO: make this better
+ if weight1_bias is not None and weight2_bias is not None and weight3_bias is not None:
+ att_q_bg = query_g.reshape(-1, query_g.shape[-1]).sum(0)
+ att_k_bg = key_g.reshape(-1, key_g.shape[-1]).sum(0)
+ att_v_bg = value_g.reshape(-1, value_g.shape[-1]).sum(0)
+ else:
+ att_q_bg = None
+ att_k_bg = None
+ att_v_bg = None
+
+ # Quantize the RoPE and FlashAttention Output. grad_input and grad_weight requires different data layout.
+ query_g, query_gs, query_g_t = fp8_quantize_pertensor_transpose(
+ query_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+ key_g, key_gs, key_g_t = fp8_quantize_pertensor_transpose(
+ key_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+ value_g, value_gs, value_g_t = fp8_quantize_pertensor_transpose(
+ value_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+
+ # Linear Layer QKV Projection
+ if qargs.weight_memory_efficient:
+ weight1_t, weight1_s = fp8_division_transpose(
+ weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True
+ )
+ weight2_t, weight2_s = fp8_division_transpose(
+ weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True
+ )
+ weight3_t, weight3_s = fp8_division_transpose(
+ weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True
+ )
+
+ fc1_g1, att_q_wg = fp8_linear_backward(
+ ln_x_t, ln_s, query_g, query_gs, query_g_t, weight1_t, weight1_s, group_size
+ )
+ fc1_g2, att_k_wg = fp8_linear_backward(ln_x_t, ln_s, key_g, key_gs, key_g_t, weight2_t, weight2_s, group_size)
+ fc1_g3, att_v_wg = fp8_linear_backward(
+ ln_x_t, ln_s, value_g, value_gs, value_g_t, weight3_t, weight3_s, group_size
+ )
+
+ fc1_g = fc1_g1 + fc1_g2 + fc1_g3
+
+ # LayerNorm
+ in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc1_g, rms_weight, rstd, group_size, num_warps)
+
+ # Add the gradient together, and prepare the input of the next layer.
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
+ fp_grad, in_g, group_size, fwobits["babit"], stochastic=False
+ )
+
+ # for autograd. forward's data type should be the same of backward tensor. this will not change the actual binary representation.
+ in_g = in_g.view(torch.float8_e4m3fn)
+
+ # Although the next operator is a linear layer in MLPResidual module, we return in_sg_g16 to make the size compatible with the forward. Otherwise it will not pass autograd.
+ return (
+ re_g,
+ in_g,
+ in_sg_g16,
+ att_q_wg,
+ None,
+ None,
+ None,
+ att_q_bg,
+ att_k_wg,
+ None,
+ None,
+ None,
+ att_k_bg,
+ att_v_wg,
+ None,
+ None,
+ None,
+ att_v_bg,
+ rms_weight_grad,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class FP8ActivationQwen2AfterAttentionResidual(FP8CacheWeightModule):
+ """
+ This is a typical transformer attention module that contains (1) Residual (2) 1 * Linear layers
+ """
+
+ def __init__(self, config: FP8ActivationQwen2Config, qargs: QuantizationConfig, layer_id):
+ super().__init__(config, qargs, layer_id)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ def forward(self, re_x, in_x):
+ if self.training:
+ if self.qargs.weight_memory_efficient:
+ # prepare for the weight
+ with torch.no_grad():
+ if FP8Manager.is_first_microbatch:
+ weight4, _, weight4_s = self.prepare_weight(
+ self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch
+ )
+ else:
+ weight4_s = self.prepare_weight(self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch)
+ weight1 = None
+
+ return _FP8ActivationQwen2AfterAttentionResidual.apply(
+ re_x,
+ in_x,
+ self.o_proj.weight,
+ weight4,
+ None,
+ weight4_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # prepare for the weight
+ with torch.no_grad():
+ weight4, weight4_t, weight4_s = self.prepare_weight(
+ self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch
+ )
+
+ return _FP8ActivationQwen2AfterAttentionResidual.apply(
+ re_x,
+ in_x,
+ self.o_proj.weight,
+ weight4,
+ weight4_t,
+ weight4_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ return re_x + self.attn_out(in_x), None, None
+
+
+class _FP8ActivationQwen2AfterAttentionResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx, re_x, flash_x, weight4_origin, weight4, weight4_t, weight4_s, group_size, fwobits, layer_id, config, qargs
+ ):
+ time_bench = os.getenv("TIME_BENCH")
+
+ if time_bench:
+ start_1 = torch.cuda.Event(enable_timing=True)
+ start_1.record()
+
+ # Quantize the FlashAttention Output
+ flash_qx, flash_s, _ = fp8_quantize_pertensor(
+ flash_x, group_size, fwobits["fabit"]
+ ) # Modified to make it memory efficient
+
+ if time_bench:
+ end_1 = torch.cuda.Event(enable_timing=True)
+ end_1.record()
+ start_2 = torch.cuda.Event(enable_timing=True)
+ start_2.record()
+
+ # # Attention Projection Linear Layer
+ if qargs.weight_memory_efficient:
+ assert weight4_t is None
+ if weight4 is None: # the second batch
+ weight4, weight4_s = fp8_division(weight4_origin, qargs.group_size, fwobits["fwbit"], weight4_s)
+
+ if time_bench:
+ end_2 = torch.cuda.Event(enable_timing=True)
+ end_2.record()
+ start_3 = torch.cuda.Event(enable_timing=True)
+ start_3.record()
+
+ fc4_x = fp8_linear_forward(flash_qx, flash_s, weight4, weight4_s, False, group_size) #
+
+ if time_bench:
+ end_3 = torch.cuda.Event(enable_timing=True)
+ end_3.record()
+ start_4 = torch.cuda.Event(enable_timing=True)
+ start_4.record()
+
+ # import IPython
+ # IPython.embed()
+ # Add the activations together
+ fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc4_x, flash_qx.dtype, group_size)
+
+ if time_bench:
+ end_4 = torch.cuda.Event(enable_timing=True)
+ end_4.record()
+
+ torch.cuda.synchronize()
+ if int(os.environ.get("LOCAL_RANK")) == 0:
+ print(
+ f"[AfterAt] Part 1: {start_1.elapsed_time(end_1):.6f} ms | "
+ f" Part 2: {start_2.elapsed_time(end_2):.6f} ms | "
+ f" Part 3: {start_3.elapsed_time(end_3):.6f} ms | "
+ f" Part 4: {start_4.elapsed_time(end_4):.6f} ms | "
+ f" Input shape: {re_x.shape}"
+ )
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(flash_x, flash_s)
+ if qargs.weight_memory_efficient:
+ assert weight4_t is None
+ ctx.weight = weight4_origin, weight4_s
+ else:
+ ctx.weight = weight4_t, weight4_s
+ ctx.group_size = group_size
+ ctx.fwobits = fwobits
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ # For autograd
+ out_x = out_x.view(torch.float8_e4m3fn)
+
+ return fp_x, out_x, out_s
+
+ @staticmethod
+ def backward(ctx, fp_grad, out_g, out_gs):
+ flash_x, flash_s = ctx.saved_tensors
+ weight4_t, weight4_s = ctx.weight
+ group_size = ctx.group_size
+ fwobits = ctx.fwobits
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # for autograd
+ if fwobits["babit"] == "E5M2":
+ # out_g = out_g.to(torch.float8_e5m2)
+ out_g = out_g.view(torch.float8_e5m2)
+ else:
+ raise ValueError("babit should be E5M2")
+ out_gs_max = out_gs.max()
+
+ # ==================== Begin backward ====================
+ # Output Projection
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
+
+ # We do not save an extra flash_x to save the memory usage
+ flash_x_t, flash_s = fp8_division_transpose(
+ flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True
+ )
+
+ if qargs.weight_memory_efficient:
+ weight4_t, weight4_s = fp8_division_transpose(
+ weight4_t, qargs.group_size, fwobits["fwbit"], weight4_s, only_transposed=True
+ )
+ fc4_g, attn_out_wg = fp8_linear_backward(
+ flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight4_t, weight4_s, group_size
+ )
+
+ return fp_grad, fc4_g, attn_out_wg, None, None, None, None, None, None, None, None
+
+
+class FP8ActivationQwen2MLPResidual(FP8CacheWeightModule):
+ """
+ This is a typical transformer MLP module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 2 / 3 Linear layers
+ (4) GELU / Silu Activation
+ """
+
+ def __init__(self, config: FP8ActivationQwen2Config, qargs: QuantizationConfig, layer_id, hidden_size: int):
+ super().__init__(config, qargs, layer_id)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.training = True
+
+ # below is only used when training = False
+ assert config.hidden_act == "silu", "We only support silu activation currently"
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, re_x, x, s, rmsnorm_weight):
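+ # re_x is the high-precision residual stream, (x, s) are the fp8-quantized hidden states
+ # and their per-group scales, and rmsnorm_weight is the post-attention RMSNorm weight.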
+ if self.training:
+ if self.qargs.weight_memory_efficient: # prepare for the weight
+ with torch.no_grad():
+ if FP8Manager.is_first_microbatch:
+ # Directly use the corresponding weight
+ weight1, _, weight1_s = self.prepare_weight(
+ self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch
+ )
+ weight2, _, weight2_s = self.prepare_weight(
+ self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch
+ )
+ weight3, _, weight3_s = self.prepare_weight(
+ self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch
+ )
+ else:
+ weight1_s = self.prepare_weight(
+ self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch
+ )
+ weight2_s = self.prepare_weight(self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch)
+ weight3_s = self.prepare_weight(
+ self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch
+ )
+
+ weight1, weight2, weight3 = None, None, None
+ return _FP8ActivationQwen2MLPResidual.apply(
+ re_x,
+ x,
+ s,
+ self.gate_proj.weight,
+ weight1,
+ None,
+ weight1_s,
+ self.up_proj.weight,
+ weight2,
+ None,
+ weight2_s,
+ self.down_proj.weight,
+ weight3,
+ None,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # prepare for the weight
+ with torch.no_grad():
+ weight1, weight1_t, weight1_s = self.prepare_weight(
+ self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch
+ )
+ weight2, weight2_t, weight2_s = self.prepare_weight(
+ self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch
+ )
+ weight3, weight3_t, weight3_s = self.prepare_weight(
+ self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch
+ )
+
+ return _FP8ActivationQwen2MLPResidual.apply(
+ re_x,
+ x,
+ s,
+ self.gate_proj.weight,
+ weight1,
+ weight1_t,
+ weight1_s,
+ self.up_proj.weight,
+ weight2,
+ weight2_t,
+ weight2_s,
+ self.down_proj.weight,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ raise NotImplementedError("Need TODO")
+ # Unreachable sketch of the intended eval path (the referenced ff_norm / ff_proj / ff_out
+ # modules are not defined on this class), kept as a comment for reference:
+ # og_x = re_x
+ # re_x = self.ff_norm(re_x)
+ # re_x = self.ff_proj(re_x)
+ # re_x = self.act(re_x)
+ # re_x = self.ff_out(re_x)
+ # re_x = og_x + re_x
+ # return re_x, None, None
+
+
+class _FP8ActivationQwen2MLPResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ re_x,
+ in_x,
+ in_s,
+ weight1_origin,
+ weight1,
+ weight1_t,
+ weight1_s,
+ weight2_origin,
+ weight2,
+ weight2_t,
+ weight2_s,
+ weight3_origin,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ group_size,
+ fwobits,
+ layer_id,
+ config,
+ qargs,
+ eps=1e-5,
+ ):
+ # For autograd
+ if fwobits["fabit"] == "E4M3":
+ # in_x = in_x.to(torch.float8_e4m3fn)
+ in_x = in_x.view(torch.float8_e4m3fn)
+ else:
+ raise ValueError("fabit should be E4M3")
+
+ # LayerNorm
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward(
+ in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True
+ )
+
+ # Linear Layer of Up Projection and Gate Projection. They are fused as one linear layer.
+ if qargs.weight_memory_efficient:
+ assert weight1_t is None and weight2_t is None and weight3_t is None # memory efficient
+ if weight1 is None:
+ assert weight2 is None and weight3 is None
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
+
+ gate_x, gate_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size) # Gate Proj
+ up_x, up_s = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, True, group_size) # Up Proj
+
+ # silu Activation
+ silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size)
+
+ # Element-wise Multiplication
+ mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True)
+
+ # Output Projection
+ if weight3 is None: # memory efficient
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
+ fc3_x = fp8_linear_forward(mul_x, mul_s, weight3, weight3_s, False, group_size)
+
+ # Add the activation together
+ fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc3_x, mul_x.dtype, group_size)
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s)
+
+ if qargs.weight_memory_efficient:
+ # weight1/2/3_origin are already held by the module, so saving them adds no extra memory.
+ assert weight1_t is None and weight2_t is None and weight3_t is None
+ ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s)
+ else:
+ # weight1/2/3_t differ from the original weights, so saving them consumes additional memory.
+ ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s)
+
+ ctx.group_size = group_size
+ ctx.ln_utils = ln_utils
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ out_x = out_x.view(torch.float8_e4m3fn)
+
+ return fp_x, out_x, out_s
+
+ @staticmethod
+ def backward(ctx, fp_grad, out_g, out_gs):
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s = ctx.saved_tensors
+
+ (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) = ctx.weight
+ group_size = ctx.group_size
+ rms_weight, rstd, num_warps = ctx.ln_utils
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # For autograd
+ if fwobits["babit"] == "E5M2":
+ # out_g = out_g.to(torch.float8_e5m2)
+ out_g = out_g.view(torch.float8_e5m2)
+ else:
+ raise ValueError("babit should be E5M2")
+ out_gs_max = out_gs.max()
+
+ # ==================== Begin backward ====================
+ # Output Projection
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
+
+ # Element-wise Multiplication gradient checkpointing
+ # silu gradient checkpointing
+ silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size)
+ mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True)
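+ # The SiLU and element-wise product are recomputed here from the saved gate/up activations
+ # instead of being stored in the forward pass, trading extra compute for lower activation memory.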
+
+ if qargs.weight_memory_efficient:
+ weight3_t, weight3_s = fp8_division_transpose(
+ weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True
+ )
+ fc3_g, weight3_grad = fp8_linear_backward(
+ mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight3_t, weight3_s, group_size
+ )
+
+ # [MEM TEST]
+ del out_g, out_g_t, weight3_t
+
+ # Element-wise Multiplication, 1 means gate, 2 means up
+ mul_g1, (mul_g2, mul_gs2, mul_g2_t) = fp8_mul_backward(
+ silu_x, silu_s, up_x, up_s, fc3_g, group_size, fwobits["babit"], output_quantized_transpose=True
+ )
+
+ # Silu activation
+ silu_g, silu_gs, silu_g_t = fp8_silu_backward(
+ gate_x, gate_s, mul_g1, group_size, fwobits["babit"], output_quantized_transpose=True
+ )
+
+ # Linear Layer of Up and Gate Projection
+ if qargs.weight_memory_efficient:
+ weight1_t, weight1_s = fp8_division_transpose(
+ weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True
+ )
+ weight2_t, weight2_s = fp8_division_transpose(
+ weight2_t, group_size, fwobits["fwbit"], weight2_s, only_transposed=True
+ )
+
+ # Gate Proj
+ fc1_g, weight1_grad = fp8_linear_backward(
+ ln_x_t, ln_s, silu_g, silu_gs, silu_g_t, weight1_t, weight1_s, group_size
+ )
+ fc2_g, weight2_grad = fp8_linear_backward(
+ ln_x_t, ln_s, mul_g2, mul_gs2, mul_g2_t, weight2_t, weight2_s, group_size
+ )
+
+ fc_g = fc1_g + fc2_g
+
+ # layerNorm
+ in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc_g, rms_weight, rstd, group_size, num_warps)
+
+ # Add the gradient together
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
+ fp_grad, in_g, group_size, fwobits["babit"], stochastic=False
+ )
+
+ in_g = in_g.view(torch.float8_e4m3fn)
+
+ return (
+ re_g,
+ in_g,
+ in_sg_g16,
+ weight1_grad,
+ None,
+ None,
+ None,
+ weight2_grad,
+ None,
+ None,
+ None,
+ weight3_grad,
+ None,
+ None,
+ None,
+ rms_weight_grad,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class FP8ActivationQwen2AttentionWithoutLinear(nn.Module):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+ and "Generating Long Sequences with Sparse Transformers".
+ """
+
+ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+ self.attention_dropout = config.attention_dropout
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)
+
+ def forward(
+ self,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class FP8ActivationQwen2FlashAttention2WithoutLinear(FP8ActivationQwen2AttentionWithoutLinear):
+ """
+ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
+ as the weights of the module stay untouched. The only required change would be on the forward pass
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
+ config.max_window_layers layers.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ):
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a `sliding_window` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ kv_seq_len = key_states.shape[-2] + cache_position[0]
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ if (
+ self.config.use_sliding_window
+ and getattr(self.config, "sliding_window", None) is not None
+ and self.layer_idx >= self.config.max_window_layers
+ ):
+ sliding_window = self.config.sliding_window
+ else:
+ sliding_window = None
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=sliding_window,
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class FP8ActivationQwen2SdpaAttentionWithoutLinear(FP8ActivationQwen2AttentionWithoutLinear):
+ """
+ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `Qwen2Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from Qwen2Attention.forward
+ def forward(
+ self,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ query_states=query_states,
+ key_states=key_states,
+ value_states=value_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and attention_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ return attn_output, None, past_key_value
+
+
+FP8LINEARQWEN2_ATTENTION_CLASSES = {
+ "eager": FP8ActivationQwen2AttentionWithoutLinear,
+ "flash_attention_2": FP8ActivationQwen2FlashAttention2WithoutLinear,
+ "sdpa": FP8ActivationQwen2SdpaAttentionWithoutLinear,
+}
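+# The decoder layer selects one of these projection-free attention implementations based on
+# config._attn_implementation ("eager", "flash_attention_2", or "sdpa").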
+
+
+class FP8ActivationQwen2DecoderLayer(nn.Module):
+ def __init__(self, config: FP8ActivationQwen2Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
+ logger.warning_once(
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
+ "unexpected results may be encountered."
+ )
+ self.self_attn = FP8LINEARQWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+
+ self.qargs = QuantizationConfig(**config.coat_fp8_args)
+ self.BeforeAttention = FP8ActivationQwen2BeforeAttentionResidual(config, self.qargs, layer_idx)
+ self.AfterAttention = FP8ActivationQwen2AfterAttentionResidual(config, self.qargs, layer_idx)
+ self.MLPResidual = FP8ActivationQwen2MLPResidual(config, self.qargs, layer_idx, self.hidden_size)
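+ # The layer is split into three FP8 residual blocks: BeforeAttention (RMSNorm + Q/K/V
+ # projections), AfterAttention (output projection + residual add), and MLPResidual
+ # (RMSNorm + gate/up/down projections + residual add); the attention math itself runs in
+ # self.self_attn without any linear layers.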
+
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ quant_hidden_states: torch.Tensor,
+ scale_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+
+ time_bench = os.getenv("TIME_BENCH")
+
+ if time_bench:
+ start_1 = torch.cuda.Event(enable_timing=True)
+ start_1.record()
+
+ # Coat: The residual, LayerNorm, and the Q/K/V Projection Linear Layer
+ residual, query_states, key_states, value_states = self.BeforeAttention(
+ hidden_states, quant_hidden_states, scale_hidden_states, self.input_layernorm.weight
+ )
+
+ if time_bench:
+ end_1 = torch.cuda.Event(enable_timing=True)
+ end_1.record()
+ start_2 = torch.cuda.Event(enable_timing=True)
+ start_2.record()
+
+ # Self Attention without any linear layer
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ query_states=query_states,
+ key_states=key_states,
+ value_states=value_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ if time_bench:
+ end_2 = torch.cuda.Event(enable_timing=True)
+ end_2.record()
+ start_3 = torch.cuda.Event(enable_timing=True)
+ start_3.record()
+
+ # Coat: The Output Projection Linear Layer and Residual
+ hidden_states, quant_hidden_states, scale_hidden_states = self.AfterAttention(residual, hidden_states)
+
+ if time_bench:
+ end_3 = torch.cuda.Event(enable_timing=True)
+ end_3.record()
+ start_4 = torch.cuda.Event(enable_timing=True)
+ start_4.record()
+
+ # Residual Connection, LayerNorm, and the whole MLP module
+ hidden_states, quant_hidden_states, scale_hidden_states = self.MLPResidual(
+ hidden_states, quant_hidden_states, scale_hidden_states, self.post_attention_layernorm.weight
+ )
+
+ if time_bench:
+ end_4 = torch.cuda.Event(enable_timing=True)
+ end_4.record()
+
+ torch.cuda.synchronize()
+ if int(os.environ.get("LOCAL_RANK", "0")) == 0:
+ print(
+ f"[Forward] Part 1: {start_1.elapsed_time(end_1):.6f} ms | "
+ f" Part 2: {start_2.elapsed_time(end_2):.6f} ms | "
+ f" Part 3: {start_3.elapsed_time(end_3):.6f} ms | "
+ f" Part 4: {start_4.elapsed_time(end_4):.6f} ms | "
+ f" Input shape: {hidden_states.shape}"
+ )
+
+ outputs = ((hidden_states, quant_hidden_states, scale_hidden_states),)
+
+ # if int(os.environ.get("LOCAL_RANK")) == 0:
+ # import IPython
+ # IPython.embed()
+ # else:
+ # import time
+ # time.sleep(1000)
+ # if output_attentions:
+ # outputs += (self_attn_weights,)
+
+ # if use_cache:
+ # outputs += (present_key_value,)
+
+ return outputs
+
+
+class FP8ActivationQwen2PreTrainedModel(Qwen2PreTrainedModel):
+ config_class = FP8ActivationQwen2Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["FP8ActivationQwen2DecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class FP8ActivationQwen2Model(FP8ActivationQwen2PreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`FP8ActivationQwen2DecoderLayer`].
+
+ Args:
+ config: FP8ActivationQwen2Config
+ """
+
+ def __init__(self, config: FP8ActivationQwen2Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [FP8ActivationQwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self._attn_implementation = config._attn_implementation
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Qwen2RotaryEmbedding(config=config)
+
+ # Quantize
+ self.qargs = QuantizationConfig(**config.coat_fp8_args)
+ self.quantize_input_before_block = Coat_quantize_bgn(self.qargs)
+ self.quantize_output_after_block = Coat_quantize_end(self.qargs)
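+ # Coat_quantize_bgn quantizes the hidden states once before the first decoder layer and
+ # Coat_quantize_end converts the fp8 stream back to high precision after the last layer,
+ # so quantized hidden states flow between decoder layers.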
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ self.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ # import warnings
+
+ # # Promote all UserWarnings to exceptions (debug helper)
+ # warnings.simplefilter("error", UserWarning)
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ # Prepare the input for Coat decoderlayer
+ hidden_states, quant_hidden_states, scale_hidden_states = self.quantize_input_before_block(hidden_states)
+
+ for layer_idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ # if int(os.environ.get("LOCAL_RANK")) == 0:
+ # print(f"FP8 Layer Idx: {layer_idx}, Input Shape: {hidden_states.shape}, Memory Allocated: {torch.cuda.memory_allocated() // 1024 ** 2} MB | Peak: {torch.cuda.max_memory_allocated() // 1024 ** 2} MB")
+
+ # NOTE: We explicitly force the decoder layers not to use gradient checkpointing here
+ # exit(0)
+ # if self.gradient_checkpointing and self.training:
+ # layer_outputs = self._gradient_checkpointing_func(
+ # decoder_layer.__call__,
+ # hidden_states,
+ # quant_hidden_states,
+ # scale_hidden_states,
+ # causal_mask,
+ # position_ids,
+ # past_key_values,
+ # output_attentions,
+ # use_cache,
+ # cache_position,
+ # position_embeddings,
+ # )
+ # else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ quant_hidden_states,
+ scale_hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ hidden_states, quant_hidden_states, scale_hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ # Summarize the output of the Decoder Layer
+ hidden_states = self.quantize_output_after_block(hidden_states, quant_hidden_states, scale_hidden_states)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+ _update_causal_mask = Qwen2Model._update_causal_mask
+
+
+class FP8ActivationQwen2ForCausalLM(FP8ActivationQwen2PreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = FP8ActivationQwen2Model(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @property
+ @lru_cache
+ def loss_function(self):
+ return LigerForCausalLMLoss
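+ # LigerForCausalLMLoss is presumably the fused cross-entropy loss from the Liger kernel
+ # library (assumed to be imported earlier in this file); exposing it via `loss_function`
+ # lets the inherited Qwen2ForCausalLM.forward use it in place of the default loss.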
+
+ forward = Qwen2ForCausalLM.forward
+
+
+AutoConfig.register("fp8activation_qwen2", FP8ActivationQwen2Config)
+AutoModel.register(FP8ActivationQwen2Config, FP8ActivationQwen2Model)
+AutoModelForCausalLM.register(FP8ActivationQwen2Config, FP8ActivationQwen2ForCausalLM)
+
+
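+# make_state_dict_compatible remaps standard Qwen2 checkpoint keys onto the reorganized module
+# layout above (q/k/v/o projections moved into BeforeAttention/AfterAttention, MLP projections
+# into MLPResidual). A minimal usage sketch, with a hypothetical checkpoint path:
+#
+#   state_dict = torch.load("path/to/qwen2/pytorch_model.bin", map_location="cpu")
+#   model.load_state_dict(make_state_dict_compatible(state_dict), strict=False)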
+def make_state_dict_compatible(state_dict: dict[str, torch.Tensor]):
+ compatible_state_dict = {}
+
+ for key, value in state_dict.items():
+ if fnmatch(key, "*self_attn.q_proj*"):
+ new_key = key.replace("self_attn.q_proj", "BeforeAttention.q_proj")
+ elif fnmatch(key, "*self_attn.k_proj*"):
+ new_key = key.replace("self_attn.k_proj", "BeforeAttention.k_proj")
+ elif fnmatch(key, "*self_attn.v_proj*"):
+ new_key = key.replace("self_attn.v_proj", "BeforeAttention.v_proj")
+ elif fnmatch(key, "*self_attn.o_proj*"):
+ new_key = key.replace("self_attn.o_proj", "AfterAttention.o_proj")
+
+ elif fnmatch(key, "*mlp.gate_proj*"):
+ new_key = key.replace("mlp.gate_proj", "MLPResidual.gate_proj")
+ elif fnmatch(key, "*mlp.up_proj*"):
+ new_key = key.replace("mlp.up_proj", "MLPResidual.up_proj")
+ elif fnmatch(key, "*mlp.down_proj*"):
+ new_key = key.replace("mlp.down_proj", "MLPResidual.down_proj")
+
+ else:
+ new_key = key
+
+ compatible_state_dict[new_key] = value
+
+ return compatible_state_dict
diff --git a/llava/model/language_model/fp8activationresidualqwen2.py b/llava/model/language_model/fp8activationresidualqwen2.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bccf84346e0e2b59ed6dc56ca1a6799bc11c5c5
--- /dev/null
+++ b/llava/model/language_model/fp8activationresidualqwen2.py
@@ -0,0 +1,1586 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen2 model."""
+
+import math
+import os
+from dataclasses import asdict, dataclass, field
+from fnmatch import fnmatch
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+from transformers.models.qwen2.modeling_qwen2 import (
+ Qwen2Attention,
+ Qwen2DecoderLayer,
+ Qwen2FlashAttention2,
+ Qwen2ForCausalLM,
+ Qwen2MLP,
+ Qwen2Model,
+ Qwen2PreTrainedModel,
+ Qwen2RMSNorm,
+ Qwen2RotaryEmbedding,
+ Qwen2SdpaAttention,
+ apply_rotary_pos_emb,
+ repeat_kv,
+ rotate_half,
+)
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+
+# FP8 related
+from ..coat.activation.models._fp8_quantization_config import QuantizationConfig
+from ..coat.activation.models._fp8_weightcache import FP8CacheWeightModule
+from ..coat.activation.models._fp8manager import FP8Manager
+from ..coat.activation.real_quantization import (
+ Coat_quantize_bgn,
+ Coat_quantize_end,
+ fp8_add_Ifp_Ifp_Ofp_Og16,
+ fp8_add_Ifp_Ifp_Ofp_Opt,
+ fp8_division,
+ fp8_division_transpose,
+ fp8_gelu_backward,
+ fp8_gelu_forward,
+ fp8_layernorm_noparam_backward,
+ fp8_layernorm_noparam_forward,
+ fp8_linear_backward,
+ fp8_linear_forward,
+ fp8_mul_backward,
+ fp8_mul_forward,
+ fp8_quantize,
+ fp8_quantize_pertensor,
+ fp8_quantize_pertensor_transpose,
+ fp8_rmsnorm_backward,
+ fp8_rmsnorm_forward,
+ fp8_silu_backward,
+ fp8_silu_forward,
+ fp8_transpose,
+)
+from ..qlinear_te import QLinearTE
+from .configuration_quantize import QuantizationConfig
+
+if is_flash_attn_2_available():
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
+
+logger = logging.get_logger(__name__)
+
+
+class FP8ActivationResidualQwen2Config(Qwen2Config):
+ model_type = "fp8activationresidual_qwen2"
+
+ def __init__(
+ self,
+ coat_fp8_args=None,
+ vocab_size=151936,
+ hidden_size=4096,
+ intermediate_size=22016,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ hidden_act="silu",
+ max_position_embeddings=32768,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ use_sliding_window=False,
+ sliding_window=4096,
+ max_window_layers=28,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_size,
+ hidden_size,
+ intermediate_size,
+ num_hidden_layers,
+ num_attention_heads,
+ num_key_value_heads,
+ hidden_act,
+ max_position_embeddings,
+ initializer_range,
+ rms_norm_eps,
+ use_cache,
+ tie_word_embeddings,
+ rope_theta,
+ rope_scaling,
+ use_sliding_window,
+ sliding_window,
+ max_window_layers,
+ attention_dropout,
+ **kwargs,
+ )
+
+ self.coat_fp8_args = coat_fp8_args
+
+
+class FP8ActivationResidualQwen2BeforeAttentionResidual(FP8CacheWeightModule):
+ """
+ This is a typical transformer attention input module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 3 Linear layers (the Q/K/V projections)
+ """
+
+ def __init__(
+ self, config: FP8ActivationResidualQwen2Config, qargs: QuantizationConfig, layer_idx: Optional[int] = None
+ ):
+ super().__init__(config, qargs, layer_idx)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+
+ def forward(self, x, s, rmsnorm_weight):
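+ # (x, s) are the fp8-quantized hidden states and their per-group scales; in this "Residual"
+ # variant the residual stream itself stays quantized, and the autograd function returns the
+ # (unchanged) quantized residual together with the projected query/key/value states.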
+ if self.training:
+ if self.qargs.weight_memory_efficient:
+ # Prepare
+ with torch.no_grad():
+ weight1_s = self.prepare_weight(self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch)
+ weight2_s = self.prepare_weight(self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch)
+ weight3_s = self.prepare_weight(self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch)
+ return _FP8ActivationResidualQwen2BeforeAttentionResidual.apply(
+ x,
+ s,
+ self.q_proj.weight,
+ None,
+ None,
+ weight1_s,
+ self.q_proj.bias,
+ self.k_proj.weight,
+ None,
+ None,
+ weight2_s,
+ self.k_proj.bias,
+ self.v_proj.weight,
+ None,
+ None,
+ weight3_s,
+ self.v_proj.bias,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # Prepare
+ with torch.no_grad():
+ weight1, weight1_t, weight1_s = self.prepare_weight(
+ self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch
+ )
+ weight2, weight2_t, weight2_s = self.prepare_weight(
+ self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch
+ )
+ weight3, weight3_t, weight3_s = self.prepare_weight(
+ self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch
+ )
+ return _FP8ActivationResidualQwen2BeforeAttentionResidual.apply(
+ x,
+ s,
+ self.q_proj.weight,
+ weight1,
+ weight1_t,
+ weight1_s,
+ self.k_proj.weight,
+ weight2,
+ weight2_t,
+ weight2_s,
+ self.v_proj.weight,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ raise NotImplementedError("This should be implemented in the future")
+ # Unreachable sketch (att_proj / attn_norm are not defined on this class):
+ # return re_x, self.att_proj(self.attn_norm(re_x))
+
+
+class _FP8ActivationResidualQwen2BeforeAttentionResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ in_x,
+ in_s,
+ weight1_origin,
+ weight1,
+ weight1_t,
+ weight1_s,
+ weight1_bias,
+ weight2_origin,
+ weight2,
+ weight2_t,
+ weight2_s,
+ weight2_bias,
+ weight3_origin,
+ weight3,
+ weight3_t,
+ weight3_s,
+ weight3_bias,
+ rmsnorm_weight,
+ group_size,
+ fwobits,
+ layer_id,
+ config,
+ qargs,
+ eps=1e-5,
+ ):
+ # for autograd
+ if fwobits["fabit"] == "E4M3":
+ # in_x = in_x.to(torch.float8_e4m3fn)
+ in_x = in_x.view(torch.float8_e4m3fn)
+ else:
+ raise ValueError("fabit should be E4M3")
+
+ # LayerNorm
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward(
+ in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True
+ )
+
+ # Linear Layer QKV Projection
+ if qargs.weight_memory_efficient:
+ assert weight1 is None # memory efficient
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
+
+ fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size, bias=weight1_bias) # query states
+ fc2_x = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, False, group_size, bias=weight2_bias) # key states
+ fc3_x = fp8_linear_forward(ln_x, ln_s, weight3, weight3_s, False, group_size, bias=weight3_bias) # value states
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s)
+ if qargs.weight_memory_efficient:
+ assert weight1_t is None and weight2_t is None and weight3_t is None
+ ctx.weight = weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s
+ else:
+ ctx.weight = weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s
+ ctx.bias = weight1_bias, weight2_bias, weight3_bias
+
+ ctx.group_size = group_size
+ ctx.ln_utils = ln_utils
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ return in_x, in_s, fc1_x, fc2_x, fc3_x
+
+ @staticmethod
+ def backward(ctx, q_grad, s_grad, query_g, key_g, value_g):
+ in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors
+ weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s = ctx.weight
+ weight1_bias, weight2_bias, weight3_bias = ctx.bias
+
+ group_size = ctx.group_size
+ rms_weight, rstd, num_warps = ctx.ln_utils
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # ==================== Begin backward ====================
+ # Gradient of Bias TODO: make this better
+ if weight1_bias is not None and weight2_bias is not None and weight3_bias is not None:
+ att_q_bg = query_g.reshape(-1, query_g.shape[-1]).sum(0)
+ att_k_bg = key_g.reshape(-1, key_g.shape[-1]).sum(0)
+ att_v_bg = value_g.reshape(-1, value_g.shape[-1]).sum(0)
+ else:
+ att_q_bg = None
+ att_k_bg = None
+ att_v_bg = None
+
+ # Quantize the RoPE and FlashAttention Output. grad_input and grad_weight requires different data layout.
+ query_g, query_gs, query_g_t = fp8_quantize_pertensor_transpose(
+ query_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+ key_g, key_gs, key_g_t = fp8_quantize_pertensor_transpose(
+ key_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+ value_g, value_gs, value_g_t = fp8_quantize_pertensor_transpose(
+ value_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
+ )
+
+ # Linear Layer QKV Projection
+ if qargs.weight_memory_efficient:
+ weight1_t, weight1_s = fp8_division_transpose(
+ weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True
+ )
+ weight2_t, weight2_s = fp8_division_transpose(
+ weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True
+ )
+ weight3_t, weight3_s = fp8_division_transpose(
+ weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True
+ )
+
+ fc1_g1, att_q_wg = fp8_linear_backward(
+ ln_x_t, ln_s, query_g, query_gs, query_g_t, weight1_t, weight1_s, group_size
+ )
+ fc1_g2, att_k_wg = fp8_linear_backward(ln_x_t, ln_s, key_g, key_gs, key_g_t, weight2_t, weight2_s, group_size)
+ fc1_g3, att_v_wg = fp8_linear_backward(
+ ln_x_t, ln_s, value_g, value_gs, value_g_t, weight3_t, weight3_s, group_size
+ )
+
+ fc1_g = fc1_g1 + fc1_g2 + fc1_g3
+
+ # LayerNorm
+ in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc1_g, rms_weight, rstd, group_size, num_warps)
+
+ # Add the gradient together, and prepare the input of the next layer.
+ in_g, in_sg, in_sg_g16 = fp8_add_Ig16_Ifp_Opt(
+ q_grad, s_grad, in_g, group_size, fwobits["babit"], stochastic=False
+ )
+
+ # for autograd. forward's data type should be the same of backward tensor. this will not change the actual binary representation.
+ in_g = in_g.view(torch.float8_e4m3fn)
+
+ # Although the next operator is a linear layer in MLPResidual module, we return in_sg_g16 to make the size compatible with the forward. Otherwise it will not pass autograd.
+ return (
+ in_g,
+ in_sg_g16,
+ att_q_wg,
+ None,
+ None,
+ None,
+ att_q_bg,
+ att_k_wg,
+ None,
+ None,
+ None,
+ att_k_bg,
+ att_v_wg,
+ None,
+ None,
+ None,
+ att_v_bg,
+ rms_weight_grad,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class FP8ActivationResidualQwen2AfterAttentionResidual(FP8CacheWeightModule):
+ """
+ This is a typical transformer attention output module that contains (1) a residual connection and (2) one Linear layer (the o_proj output projection)
+ """
+
+ def __init__(self, config: FP8ActivationResidualQwen2Config, qargs: QuantizationConfig, layer_id):
+ super().__init__(config, qargs, layer_id)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ def forward(self, re_qx, re_sx, in_x):
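+ # re_qx / re_sx are the fp8-quantized residual and its per-group scales, and in_x is the
+ # attention output to be projected by o_proj and added back onto the residual.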
+ if self.training:
+ if self.qargs.weight_memory_efficient:
+ # prepare for the weight
+ with torch.no_grad():
+ weight4_s = self.prepare_weight(self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch)
+
+ return _FP8ActivationResidualQwen2AfterAttentionResidual.apply(
+ re_qx,
+ re_sx,
+ in_x,
+ self.o_proj.weight,
+ None,
+ None,
+ weight4_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # prepare for the weight
+ with torch.no_grad():
+ weight4, weight4_t, weight4_s = self.prepare_weight(
+ self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch
+ )
+
+ return _FP8ActivationResidualQwen2AfterAttentionResidual.apply(
+ re_qx,
+ re_sx,
+ in_x,
+ self.o_proj.weight,
+ weight4,
+ weight4_t,
+ weight4_s,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # Inference path: apply the output projection and add the residual (assumes the residual is provided in full precision at eval time).
+ return re_qx + self.o_proj(in_x), None, None
+
+
+class _FP8ActivationResidualQwen2AfterAttentionResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ re_qx,
+ re_sx,
+ flash_x,
+ weight4_origin,
+ weight4,
+ weight4_t,
+ weight4_s,
+ group_size,
+ fwobits,
+ layer_id,
+ config,
+ qargs,
+ ):
+ # Quantize the FlashAttention Output
+ flash_qx, flash_s, _ = fp8_quantize_pertensor(
+ flash_x, group_size, fwobits["fabit"]
+ ) # Modified to make it memory efficient
+
+ # Attention Output Projection Linear Layer
+ if qargs.weight_memory_efficient:
+ assert weight4 is None # memory efficient
+ weight4, weight4_s = fp8_division(weight4_origin, qargs.group_size, fwobits["fwbit"], weight4_s)
+ fc4_x = fp8_linear_forward(flash_qx, flash_s, weight4, weight4_s, False, group_size)
+
+ # Add the activations together
+ fp_x, (out_x, out_s) = fp8_add_Ig16_Ifp_Ofp_Og16(re_qx, re_sx, fc4_x, flash_qx.dtype, group_size)
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(flash_x, flash_s)
+ if qargs.weight_memory_efficient:
+ assert weight4_t is None
+ ctx.weight = weight4_origin, weight4_s
+ else:
+ ctx.weight = weight4_t, weight4_s
+ ctx.group_size = group_size
+ ctx.fwobits = fwobits
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ # For autograd
+ out_x = out_x.view(torch.float8_e4m3fn)
+
+ return fp_x, out_x, out_s
+
+ @staticmethod
+ def backward(ctx, fp_grad, out_g, out_gs):
+ flash_x, flash_s = ctx.saved_tensors
+ weight4_t, weight4_s = ctx.weight
+ group_size = ctx.group_size
+ fwobits = ctx.fwobits
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # for autograd
+ if fwobits["babit"] == "E5M2":
+ # out_g = out_g.to(torch.float8_e5m2)
+ out_g = out_g.view(torch.float8_e5m2)
+ else:
+ raise ValueError("babit should be E5M2")
+ out_gs_max = out_gs.max()
+
+ # ==================== Begin backward ====================
+ # Output Projection
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
+
+ # We do not save an extra transposed copy of flash_x; recompute the transpose here to reduce memory usage.
+ flash_x_t, flash_s = fp8_division_transpose(
+ flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True
+ )
+
+ if qargs.weight_memory_efficient:
+ weight4_t, weight4_s = fp8_division_transpose(
+ weight4_t, qargs.group_size, fwobits["fwbit"], weight4_s, only_transposed=True
+ )
+ fc4_g, attn_out_wg = fp8_linear_backward(
+ flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight4_t, weight4_s, group_size
+ )
+
+ # One gradient (or None) per forward input: re_qx, re_sx, flash_x, weight4_origin, weight4, weight4_t, weight4_s, group_size, fwobits, layer_id, config, qargs.
+ return fp_grad, None, fc4_g, attn_out_wg, None, None, None, None, None, None, None, None
+
+
+class FP8ActivationResidualQwen2MLPResidual(FP8CacheWeightModule):
+ """
+ This is a typical transformer MLP module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 2 / 3 * Linear layers
+ (4) GELU / Silu Activation
+ """
+
+ def __init__(self, config: FP8ActivationResidualQwen2Config, qargs: QuantizationConfig, layer_id, hidden_size: int):
+ super().__init__(config, qargs, layer_id)
+
+ self.qargs = qargs
+ self.fwobits = {
+ "fabit": self.qargs.fabit,
+ "fwbit": self.qargs.fwbit,
+ "fobit": self.qargs.fobit,
+ "babit": self.qargs.babit,
+ "bwbit": self.qargs.bwbit,
+ "bobit": self.qargs.bobit,
+ }
+
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.training = True
+
+ # below is only used when training = False
+ assert config.hidden_act == "silu", "We only support silu activation currently"
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, re_x, x, s, rmsnorm_weight):
+ if self.training:
+ if self.qargs.weight_memory_efficient: # prepare for the weight
+ with torch.no_grad():
+ weight1_s = self.prepare_weight(self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch)
+ weight2_s = self.prepare_weight(self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch)
+ weight3_s = self.prepare_weight(self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch)
+
+ return _FP8ActivationResidualQwen2MLPResidual.apply(
+ re_x,
+ x,
+ s,
+ self.gate_proj.weight,
+ None,
+ None,
+ weight1_s,
+ self.up_proj.weight,
+ None,
+ None,
+ weight2_s,
+ self.down_proj.weight,
+ None,
+ None,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+ # prepare for the weight
+ with torch.no_grad():
+ weight1, weight1_t, weight1_s = self.prepare_weight(
+ self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch
+ )
+ weight2, weight2_t, weight2_s = self.prepare_weight(
+ self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch
+ )
+ weight3, weight3_t, weight3_s = self.prepare_weight(
+ self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch
+ )
+
+ return _FP8ActivationResidualQwen2MLPResidual.apply(
+ re_x,
+ x,
+ s,
+ self.gate_proj.weight,
+ weight1,
+ weight1_t,
+ weight1_s,
+ self.up_proj.weight,
+ weight2,
+ weight2_t,
+ weight2_s,
+ self.down_proj.weight,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ self.qargs.group_size,
+ self.fwobits,
+ self.layer_id,
+ self.config,
+ self.qargs,
+ )
+ else:
+     # The inference (eval) path of this FP8 MLP residual module has not been implemented yet.
+     raise NotImplementedError("FP8ActivationResidualQwen2MLPResidual currently only supports training mode.")
+
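+
+# For reference, a minimal full-precision sketch of what the fused FP8 kernels in the autograd
+# Function below compute (this helper is illustrative only and is not called anywhere in this
+# module; it only assumes the `torch` and `nn` imports already used above):
+def _reference_mlp_residual_fp(re_x, x, rms_weight, gate_w, up_w, down_w, eps=1e-5):
+    # RMSNorm in full precision
+    ln = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * rms_weight
+    # SwiGLU MLP followed by the residual connection
+    gate = nn.functional.silu(nn.functional.linear(ln, gate_w))
+    up = nn.functional.linear(ln, up_w)
+    return re_x + nn.functional.linear(gate * up, down_w)
+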
+
+class _FP8ActivationResidualQwen2MLPResidual(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ re_x,
+ in_x,
+ in_s,
+ weight1_origin,
+ weight1,
+ weight1_t,
+ weight1_s,
+ weight2_origin,
+ weight2,
+ weight2_t,
+ weight2_s,
+ weight3_origin,
+ weight3,
+ weight3_t,
+ weight3_s,
+ rmsnorm_weight,
+ group_size,
+ fwobits,
+ layer_id,
+ config,
+ qargs,
+ eps=1e-5,
+ ):
+ # For autograd
+ if fwobits["fabit"] == "E4M3":
+ # in_x = in_x.to(torch.float8_e4m3fn)
+ in_x = in_x.view(torch.float8_e4m3fn)
+ else:
+ raise ValueError("fabit should be E4M3")
+
+ # LayerNorm
+ ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward(
+ in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True
+ )
+
+ # Linear layers of the Gate Projection and Up Projection (computed as two FP8 linear forwards below).
+ if qargs.weight_memory_efficient:
+ assert weight1 is None and weight2 is None and weight3 is None # memory efficient
+ weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
+ weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
+
+ gate_x, gate_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size) # Gate Proj
+ up_x, up_s = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, True, group_size) # Up Proj
+
+ # silu Activation
+ silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size)
+
+ # Element-wise Multiplication
+ mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True)
+
+ # Output Projection
+ if weight3 is None: # memory efficient
+ weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
+ fc3_x = fp8_linear_forward(mul_x, mul_s, weight3, weight3_s, False, group_size)
+
+ # Add the activation together
+ out_x, out_s = fp8_add_Ifp_Ifp_Og16(re_x, fc3_x, mul_x.dtype, group_size)
+
+ # ==================== save for backward ====================
+ ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s)
+
+ if qargs.weight_memory_efficient:
+     # The original weight tensors are parameters kept alive elsewhere, so only their scales need to be stored here.
+     assert weight1_t is None and weight2_t is None and weight3_t is None
+     ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s)
+ else:
+     # weight1/2/3_t differ from the original weights, so saving them costs additional memory.
+     ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s)
+
+ ctx.group_size = group_size
+ ctx.ln_utils = ln_utils
+ ctx.utils = fwobits, layer_id, config, qargs
+
+ out_x = out_x.view(torch.float8_e4m3fn)
+
+ return out_x, out_s
+
+ @staticmethod
+ def backward(ctx, out_g, out_gs):
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s, silu_x, silu_s, mul_x_t, mul_s = ctx.saved_tensors
+
+ (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) = ctx.weight
+ group_size = ctx.group_size
+ rms_weight, rstd, num_warps = ctx.ln_utils
+ fwobits, layer_id, config, qargs = ctx.utils
+
+ # For autograd
+ if fwobits["babit"] == "E5M2":
+ # out_g = out_g.to(torch.float8_e5m2)
+ out_g = out_g.view(torch.float8_e5m2)
+ else:
+ raise ValueError("babit should be E5M2")
+ out_gs_max = out_gs.max()
+
+ # ==================== Begin backward ====================
+ # Output Projection
+ out_g_t = fp8_transpose(out_g, transpose_output_2d=True)
+
+ if qargs.weight_memory_efficient:
+ weight3_t, weight3_s = fp8_division_transpose(
+ weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True
+ )
+ fc3_g, weight3_grad = fp8_linear_backward(
+ mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight3_t, weight3_s, group_size
+ )
+
+ # [MEM TEST]
+ del out_g, out_g_t, weight3_t
+
+ # Element-wise Multiplication, 1 means gate, 2 means up
+ mul_g1, (mul_g2, mul_gs2, mul_g2_t) = fp8_mul_backward(
+ silu_x, silu_s, up_x, up_s, fc3_g, group_size, fwobits["babit"], output_quantized_transpose=True
+ )
+
+ # Silu activation
+ silu_g, silu_gs, silu_g_t = fp8_silu_backward(
+ gate_x, gate_s, mul_g1, group_size, fwobits["babit"], output_quantized_transpose=True
+ )
+
+ # Linear Layer of Up and Gate Projection
+ if qargs.weight_memory_efficient:
+ weight1_t, weight1_s = fp8_division_transpose(
+ weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True
+ )
+ weight2_t, weight2_s = fp8_division_transpose(
+ weight2_t, group_size, fwobits["fwbit"], weight2_s, only_transposed=True
+ )
+
+ # Gate Proj (fc1) and Up Proj (fc2)
+ fc1_g, weight1_grad = fp8_linear_backward(
+ ln_x_t, ln_s, silu_g, silu_gs, silu_g_t, weight1_t, weight1_s, group_size
+ )
+ fc2_g, weight2_grad = fp8_linear_backward(
+ ln_x_t, ln_s, mul_g2, mul_gs2, mul_g2_t, weight2_t, weight2_s, group_size
+ )
+
+ fc_g = fc1_g + fc2_g
+
+ # RMSNorm
+ in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc_g, rms_weight, rstd, group_size, num_warps)
+
+ # Add the gradient together
+ re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
+ out_g, out_gs_max, in_g, group_size, fwobits["babit"], stochastic=False
+ )
+
+ in_g = in_g.view(torch.float8_e4m3fn)
+
+ return (
+ re_g,
+ in_g,
+ in_sg_g16,
+ weight1_grad,
+ None,
+ None,
+ None,
+ weight2_grad,
+ None,
+ None,
+ None,
+ weight3_grad,
+ None,
+ None,
+ None,
+ rms_weight_grad,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class FP8ActivationResidualQwen2AttentionWithoutLinear(nn.Module):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+ and "Generating Long Sequences with Sparse Transformers".
+ """
+
+ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+ self.attention_dropout = config.attention_dropout
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)
+
+ def forward(
+ self,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class FP8ActivationResidualQwen2FlashAttention2WithoutLinear(FP8ActivationResidualQwen2AttentionWithoutLinear):
+ """
+ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
+ as the weights of the module stays untouched. The only required change would be on the forward pass
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
+ config.max_window_layers layers.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ):
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a `sliding_window` attribute set
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ kv_seq_len = key_states.shape[-2] + cache_position[0]
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ if (
+ self.config.use_sliding_window
+ and getattr(self.config, "sliding_window", None) is not None
+ and self.layer_idx >= self.config.max_window_layers
+ ):
+ sliding_window = self.config.sliding_window
+ else:
+ sliding_window = None
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=sliding_window,
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class FP8ActivationResidualQwen2SdpaAttentionWithoutLinear(FP8ActivationResidualQwen2AttentionWithoutLinear):
+ """
+ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from Qwen2Attention.forward
+ def forward(
+ self,
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ query_states=query_states,
+ key_states=key_states,
+ value_states=value_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ bsz, q_len, _ = query_states.size()
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and attention_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ return attn_output, None, past_key_value
+
+
+FP8LINEARRESIDUALQWEN2_ATTENTION_CLASSES = {
+ "eager": FP8ActivationResidualQwen2AttentionWithoutLinear,
+ "flash_attention_2": FP8ActivationResidualQwen2FlashAttention2WithoutLinear,
+ "sdpa": FP8ActivationResidualQwen2SdpaAttentionWithoutLinear,
+}
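+
+# The key into this mapping comes from `config._attn_implementation`; with recent transformers
+# versions this is typically selected by passing `attn_implementation="flash_attention_2"`
+# (or "sdpa" / "eager") to `from_pretrained` / `from_config`. The exact plumbing depends on the
+# installed transformers version, so treat this note as a sketch.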
+
+
+class FP8ActivationResidualQwen2DecoderLayer(nn.Module):
+ def __init__(self, config: FP8ActivationResidualQwen2Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
+ logger.warning_once(
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
+ "unexpected results may be encountered."
+ )
+ self.self_attn = FP8LINEARRESIDUALQWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+
+ self.qargs = QuantizationConfig(**config.coat_fp8_args)
+ self.BeforeAttention = FP8ActivationResidualQwen2BeforeAttentionResidual(config, self.qargs, layer_idx)
+ self.AfterAttention = FP8ActivationResidualQwen2AfterAttentionResidual(config, self.qargs, layer_idx)
+ self.MLPResidual = FP8ActivationResidualQwen2MLPResidual(config, self.qargs, layer_idx, self.hidden_size)
+
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ quant_hidden_states: torch.Tensor,
+ scale_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ quant_hidden_states (`torch.Tensor`): FP8-quantized input to the layer of shape `(batch, seq_len, embed_dim)`
+ scale_hidden_states (`torch.Tensor`): quantization scales associated with `quant_hidden_states`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+
+ # Coat: The residual, LayerNorm, and the Q/K/V Projection Linear Layer
+ residual_quant, residual_scale, query_states, key_states, value_states = self.BeforeAttention(
+ quant_hidden_states, scale_hidden_states, self.input_layernorm.weight
+ )
+
+ # Self Attention without any linear layer
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ query_states=query_states,
+ key_states=key_states,
+ value_states=value_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ # Coat: The Output Projection Linear Layer and Residual
+ hidden_states, quant_hidden_states, scale_hidden_states = self.AfterAttention(
+ residual_quant, residual_scale, hidden_states
+ )
+
+ # Residual Connection, LayerNorm, and the whole MLP module
+ quant_hidden_states, scale_hidden_states = self.MLPResidual(
+ hidden_states, quant_hidden_states, scale_hidden_states, self.post_attention_layernorm.weight
+ )
+
+ # The residual stream stays in FP8 between layers: pass on the (quantized, scale) pair.
+ outputs = ((quant_hidden_states, scale_hidden_states),)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class FP8ActivationResidualQwen2PreTrainedModel(Qwen2PreTrainedModel):
+ config_class = FP8ActivationResidualQwen2Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["FP8ActivationResidualQwen2DecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class FP8ActivationResidualQwen2Model(FP8ActivationResidualQwen2PreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
+
+ Args:
+ config: Qwen2Config
+ """
+
+ def __init__(self, config: FP8ActivationResidualQwen2Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [FP8ActivationResidualQwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self._attn_implementation = config._attn_implementation
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Qwen2RotaryEmbedding(config=config)
+
+ # Quantize
+ self.qargs = QuantizationConfig(**config.coat_fp8_args)
+ self.quantize_input_before_block = Coat_quantize_bgn(self.qargs)
+ self.quantize_output_after_block = Coat_quantize_end(self.qargs)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ if past_key_values is None:
+ past_key_values = DynamicCache()
+ else:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ # Prepare the input for Coat decoderlayer
+ quant_hidden_states, scale_hidden_states = self.quantize_input_before_block(hidden_states)
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ quant_hidden_states,
+ scale_hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ quant_hidden_states,
+ scale_hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ quant_hidden_states, scale_hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ # Summarize the output of the Decoder Layer
+ hidden_states = self.quantize_output_after_block(quant_hidden_states, scale_hidden_states)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+ _update_causal_mask = Qwen2Model._update_causal_mask
+
+
+class FP8ActivationResidualQwen2ForCausalLM(FP8ActivationResidualQwen2PreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = FP8ActivationResidualQwen2Model(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ forward = Qwen2ForCausalLM.forward
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
+ prepare_inputs_for_generation = Qwen2ForCausalLM.prepare_inputs_for_generation
+
+
+AutoConfig.register("fp8activationresidual_qwen2", FP8ActivationResidualQwen2Config)
+AutoModel.register(FP8ActivationResidualQwen2Config, FP8ActivationResidualQwen2Model)
+AutoModelForCausalLM.register(FP8ActivationResidualQwen2Config, FP8ActivationResidualQwen2ForCausalLM)
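+
+
+# With the registrations above, the quantized variant can be built through the Auto classes.
+# Sketch (the path and `coat_fp8_args` values are placeholders and must match QuantizationConfig):
+#
+#   config = FP8ActivationResidualQwen2Config.from_pretrained("/path/to/qwen2_config")
+#   config.coat_fp8_args = {...}  # FP8 format / group-size settings
+#   model = AutoModelForCausalLM.from_config(config)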
+
+
+def make_state_dict_compatible(state_dict: dict[str, torch.Tensor]):
+ compatible_state_dict = {}
+
+ for key, value in state_dict.items():
+ if fnmatch(key, "*self_attn.q_proj*"):
+ new_key = key.replace("self_attn.q_proj", "BeforeAttention.q_proj")
+ elif fnmatch(key, "*self_attn.k_proj*"):
+ new_key = key.replace("self_attn.k_proj", "BeforeAttention.k_proj")
+ elif fnmatch(key, "*self_attn.v_proj*"):
+ new_key = key.replace("self_attn.v_proj", "BeforeAttention.v_proj")
+ elif fnmatch(key, "*self_attn.o_proj*"):
+ new_key = key.replace("self_attn.o_proj", "AfterAttention.o_proj")
+
+ elif fnmatch(key, "*mlp.gate_proj*"):
+ new_key = key.replace("mlp.gate_proj", "MLPResidual.gate_proj")
+ elif fnmatch(key, "*mlp.up_proj*"):
+ new_key = key.replace("mlp.up_proj", "MLPResidual.up_proj")
+ elif fnmatch(key, "*mlp.down_proj*"):
+ new_key = key.replace("mlp.down_proj", "MLPResidual.down_proj")
+
+ else:
+ new_key = key
+
+ compatible_state_dict[new_key] = value
+
+ return compatible_state_dict
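+
+
+# Usage sketch (paths are hypothetical): remap a stock Qwen2 checkpoint into the module layout
+# above (q/k/v_proj -> BeforeAttention.*, o_proj -> AfterAttention.*, mlp.* -> MLPResidual.*):
+#
+#   base_sd = Qwen2ForCausalLM.from_pretrained("/path/to/qwen2_base").state_dict()
+#   model.load_state_dict(make_state_dict_compatible(base_sd), strict=False)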
diff --git a/llava/model/language_model/fp8linearqwen2.py b/llava/model/language_model/fp8linearqwen2.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ae93bb8d6798e9f327198abc405866e4a1527e2
--- /dev/null
+++ b/llava/model/language_model/fp8linearqwen2.py
@@ -0,0 +1,356 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Qwen2 model."""
+
+import math
+from functools import lru_cache
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+from transformers.models.qwen2.modeling_qwen2 import (
+ Qwen2Attention,
+ Qwen2DecoderLayer,
+ Qwen2FlashAttention2,
+ Qwen2ForCausalLM,
+ Qwen2MLP,
+ Qwen2Model,
+ Qwen2PreTrainedModel,
+ Qwen2RMSNorm,
+ Qwen2RotaryEmbedding,
+ Qwen2SdpaAttention,
+ apply_rotary_pos_emb,
+ repeat_kv,
+ rotate_half,
+)
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+
+from ..liger.cross_entropy import LigerForCausalLMLoss
+from ..qlinear_te import QLinearTE
+from .configuration_quantize import QuantizationConfig
+
+if is_flash_attn_2_available():
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
+
+logger = logging.get_logger(__name__)
+
+
+class FP8LinearQwen2Config(Qwen2Config):
+ model_type = "fp8linear_qwen2"
+
+ def __init__(
+ self,
+ coat_fp8_args=None,
+ vocab_size=151936,
+ hidden_size=4096,
+ intermediate_size=22016,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=32,
+ hidden_act="silu",
+ max_position_embeddings=32768,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ use_sliding_window=False,
+ sliding_window=4096,
+ max_window_layers=28,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_size,
+ hidden_size,
+ intermediate_size,
+ num_hidden_layers,
+ num_attention_heads,
+ num_key_value_heads,
+ hidden_act,
+ max_position_embeddings,
+ initializer_range,
+ rms_norm_eps,
+ use_cache,
+ tie_word_embeddings,
+ rope_theta,
+ rope_scaling,
+ use_sliding_window,
+ sliding_window,
+ max_window_layers,
+ attention_dropout,
+ **kwargs,
+ )
+
+ self.coat_fp8_args = coat_fp8_args
+
+
+# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
+class FP8LinearQwen2MLP(Qwen2MLP):
+ def __init__(self, config, layer_idx):
+ super().__init__(config)
+ # self.gate_proj = te.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ # self.up_proj = te.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ # self.down_proj = te.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+ self.gate_proj = QLinearTE(
+ self.hidden_size, self.intermediate_size, bias=False, args=config.coat_fp8_args, layer_idx=layer_idx
+ )
+ self.up_proj = QLinearTE(
+ self.hidden_size, self.intermediate_size, bias=False, args=config.coat_fp8_args, layer_idx=layer_idx
+ )
+ self.down_proj = QLinearTE(
+ self.intermediate_size, self.hidden_size, bias=False, args=config.coat_fp8_args, layer_idx=layer_idx
+ )
+
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
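+
+    # Note: unlike the stock Qwen2MLP, every projection above runs through QLinearTE, so the
+    # matmuls take the FP8 quantized path configured by `config.coat_fp8_args`.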
+
+
+class FP8LinearQwen2Attention(Qwen2Attention):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+ and "Generating Long Sequences with Sparse Transformers".
+ """
+
+ def __init__(self, config: FP8LinearQwen2Config, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx)
+
+ self.q_proj = QLinearTE(
+ self.hidden_size,
+ self.num_heads * self.head_dim,
+ bias=True,
+ args=config.coat_fp8_args,
+ layer_idx=layer_idx,
+ )
+ self.k_proj = QLinearTE(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=True,
+ args=config.coat_fp8_args,
+ layer_idx=layer_idx,
+ )
+ self.v_proj = QLinearTE(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=True,
+ args=config.coat_fp8_args,
+ layer_idx=layer_idx,
+ )
+ self.o_proj = QLinearTE(
+ self.num_heads * self.head_dim,
+ self.hidden_size,
+ bias=False,
+ args=config.coat_fp8_args,
+ layer_idx=layer_idx,
+ )
+
+ forward = Qwen2Attention.forward
+
+
+class FP8LinearQwen2FlashAttention2(FP8LinearQwen2Attention):
+ """
+ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
+ as the weights of the module stays untouched. The only required change would be on the forward pass
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
+ config.max_window_layers layers.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ forward = Qwen2FlashAttention2.forward
+
+
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2
+class FP8LinearQwen2SdpaAttention(FP8LinearQwen2Attention):
+ """
+ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from Qwen2Attention.forward
+ forward = Qwen2SdpaAttention.forward
+
+
+FP8LINEARQWEN2_ATTENTION_CLASSES = {
+ "eager": FP8LinearQwen2Attention,
+ "flash_attention_2": FP8LinearQwen2FlashAttention2,
+ "sdpa": FP8LinearQwen2SdpaAttention,
+}
+
+
+class FP8LinearQwen2DecoderLayer(nn.Module):
+ def __init__(self, config: FP8LinearQwen2Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
+ logger.warning_once(
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
+ "unexpected results may be encountered."
+ )
+ self.self_attn = FP8LINEARQWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+
+ self.mlp = FP8LinearQwen2MLP(config, layer_idx)
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ forward = Qwen2DecoderLayer.forward
+
+
+class FP8LinearQwen2PreTrainedModel(Qwen2PreTrainedModel):
+ config_class = FP8LinearQwen2Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["FP8LinearQwen2DecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class FP8LinearQwen2Model(FP8LinearQwen2PreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
+
+ Args:
+ config: Qwen2Config
+ """
+
+ def __init__(self, config: FP8LinearQwen2Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [FP8LinearQwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self._attn_implementation = config._attn_implementation
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Qwen2RotaryEmbedding(config=config)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ forward = Qwen2Model.forward
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+ _update_causal_mask = Qwen2Model._update_causal_mask
+
+
+class FP8LinearQwen2ForCausalLM(FP8LinearQwen2PreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = FP8LinearQwen2Model(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @property
+ @lru_cache
+ def loss_function(self):
+ return LigerForCausalLMLoss
+
+ forward = Qwen2ForCausalLM.forward
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
+ prepare_inputs_for_generation = Qwen2ForCausalLM.prepare_inputs_for_generation
+
+
+AutoConfig.register("fp8linear_qwen2", FP8LinearQwen2Config)
+AutoModel.register(FP8LinearQwen2Config, FP8LinearQwen2Model)
+AutoModelForCausalLM.register(FP8LinearQwen2Config, FP8LinearQwen2ForCausalLM)
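+
+# Note on the two FP8 variants: the FP8LinearQwen2* classes above only swap nn.Linear for
+# QLinearTE and reuse the stock Qwen2 forward passes, whereas the FP8ActivationResidualQwen2*
+# classes additionally keep activations and the residual stream in FP8 through custom autograd
+# Functions.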
diff --git a/llava/model/language_model/llava_llama.py b/llava/model/language_model/llava_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8fbe3f845abe162edd4d891cca35a752b0966da
--- /dev/null
+++ b/llava/model/language_model/llava_llama.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+
+import os
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from llava.model.loss import soft_cross_entropy
+from llava.model.utils.packing import set_seqlens_in_batch
+from llava.train.sequence_parallel.globals import get_pg_manager
+from llava.utils.logging import logger
+
+from ...train.utils import calculate_loss_weight
+from ..configuration_llava import LlavaConfig
+from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
+
+
+class LlavaLlamaConfig(LlavaConfig):
+ model_type = "llava_llama"
+
+
+# FIXME we will follow the convention to add a new class for CausalLM in the future
+class LlavaLlamaModel(LlavaMetaModel, LlavaMetaForCausalLM, PreTrainedModel):
+ config_class = LlavaLlamaConfig
+ main_input_name = "input_embeds"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = True
+
+ def __init__(self, config: LlavaLlamaConfig = None, *args, **kwargs) -> None:
+ super().__init__(config)
+ self.init_vlm(config=config, *args, **kwargs)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+ *model_args,
+ config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ ignore_mismatched_sizes: bool = False,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: Optional[Union[str, bool]] = None,
+ revision: str = "main",
+ use_safetensors: bool = None,
+ **kwargs,
+ ):
+ if hasattr(cls, "load_pretrained"):
+ return cls.load_pretrained(
+ pretrained_model_name_or_path,
+ *model_args,
+ config=config,
+ cache_dir=cache_dir,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ force_download=force_download,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ use_safetensors=use_safetensors,
+ **kwargs,
+ )
+ return super().from_pretrained(
+ pretrained_model_name_or_path,
+ *model_args,
+ config=config,
+ cache_dir=cache_dir,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ force_download=force_download,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ use_safetensors=use_safetensors,
+ **kwargs,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ media: Optional[Dict[str, List[torch.Tensor]]] = None,
+ images: Optional[torch.FloatTensor] = None,
+ media_config: Optional[List] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ media_meta: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ packing: bool = True,
+ force_packing: bool = False,
+ seqlens_in_batch: Optional[torch.LongTensor] = None,
+ dpo_forward: bool = False,
+ **kwargs,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ self.freezed_module_patch()
+
+ if images is not None:
+ if media is not None:
+ raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
+ logger.warning("The 'images' argument is deprecated. Please use 'media' instead.")
+ media = {"image": images}
+
+ if media_config is None:
+ media_config = defaultdict(dict)
+ if inputs_embeds is None:
+ inputs_embeds, labels, attention_mask = self._embed(input_ids, media, media_config, labels, attention_mask, media_meta)
+
+ if force_packing or (packing and self.training and not dpo_forward):
+ if seqlens_in_batch is None:
+ seqlens_in_batch = torch.sum(attention_mask, dim=1)
+ set_seqlens_in_batch(seqlens_in_batch)
+
+ (inputs_embeds, attention_mask, position_ids, labels) = self.repack_multimodal_data(
+ inputs_embeds, attention_mask, position_ids, labels
+ )
+
+ outputs = self.llm(
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ labels=labels,
+ **kwargs,
+ )
+
+ if self.training and getattr(self.config, "time_token_ids", []):
+ outputs.loss = soft_cross_entropy(
+ outputs.logits,
+ labels,
+ soft_tokens=self.config.time_token_ids,
+ std=self.config.soft_ce_std,
+ )
+
+ # Rescale the loss for sequence parallelism (SP)
+ if get_pg_manager() is not None:
+ loss_weight = calculate_loss_weight(labels)
+ outputs.loss = outputs.loss * loss_weight
+
+ if dpo_forward:
+ return outputs.logits, labels
+
+ return outputs
+
+
+AutoConfig.register("llava_llama", LlavaLlamaConfig)
+AutoModel.register(LlavaLlamaConfig, LlavaLlamaModel)
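+
+# Minimal usage sketch (illustrative; assumes the surrounding llava package is importable and
+# the checkpoint path below is only a placeholder):
+#
+#   from transformers import AutoConfig, AutoModel
+#   cfg = AutoConfig.from_pretrained("/path/to/af3-checkpoint")   # resolves to LlavaLlamaConfig
+#   model = AutoModel.from_pretrained("/path/to/af3-checkpoint", config=cfg)
+#
+# The two register() calls above are what let the Auto* factories map the "llava_llama"
+# model_type to LlavaLlamaConfig / LlavaLlamaModel.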
diff --git a/llava/model/language_model/qllama.py b/llava/model/language_model/qllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..db823b94634ec5cfeb589319e947f64a473f3a30
--- /dev/null
+++ b/llava/model/language_model/qllama.py
@@ -0,0 +1,312 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch LLaMA model."""
+import math
+import os
+import time
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from flash_attn import flash_attn_func, flash_attn_varlen_func
+from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+from transformers.activations import ACT2FN
+from transformers.modeling_flash_attention_utils import _flash_attention_forward
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaDynamicNTKScalingRotaryEmbedding,
+ LlamaFlashAttention2,
+ LlamaForCausalLM,
+ LlamaForSequenceClassification,
+ LlamaLinearScalingRotaryEmbedding,
+ LlamaMLP,
+ LlamaModel,
+ LlamaPreTrainedModel,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+ LlamaSdpaAttention,
+ apply_rotary_pos_emb,
+ repeat_kv,
+ rotate_half,
+)
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+
+from ..qlinear_te import QLinearTE
+
+try:
+ import transformer_engine.pytorch as te
+except ImportError:
+ pass  # transformer_engine is optional; it is only needed by the TE-backed quantized layers.
+from ..qfunction import *
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "QLlamaConfig"
+
+
+class QLlamaConfig(LlamaConfig):
+ model_type = "qllama"
+
+
+class QLlamaMLP(LlamaMLP):
+ def __init__(self, config, layer_idx):
+ super().__init__(config)
+ self.layer_idx = layer_idx
+
+ # self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ # self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ # self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.gate_proj = QLinearTE(
+ self.hidden_size, self.intermediate_size, bias=False, args=config, layer_idx=layer_idx
+ )
+ self.up_proj = QLinearTE(self.hidden_size, self.intermediate_size, bias=False, args=config, layer_idx=layer_idx)
+ self.down_proj = QLinearTE(
+ self.intermediate_size, self.hidden_size, bias=False, args=config, layer_idx=layer_idx
+ )
+
+
+class QLlamaAttention(LlamaAttention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: QLlamaConfig, layer_idx):
+ super().__init__(config)
+ self.layer_idx = layer_idx
+
+ # self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ # self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ # self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ # self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+
+ self.q_proj = QLinearTE(
+ self.hidden_size,
+ self.num_heads * self.head_dim,
+ bias=config.attention_bias,
+ args=config,
+ layer_idx=layer_idx,
+ )
+ self.k_proj = QLinearTE(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=config.attention_bias,
+ args=config,
+ layer_idx=layer_idx,
+ )
+ self.v_proj = QLinearTE(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=config.attention_bias,
+ args=config,
+ layer_idx=layer_idx,
+ )
+ self.o_proj = QLinearTE(
+ self.num_heads * self.head_dim,
+ self.hidden_size,
+ bias=config.attention_bias,
+ args=config,
+ layer_idx=layer_idx,
+ )
+
+
+class QLlamaFlashAttention2(QLlamaAttention):
+ """
+ Llama flash attention module. This module inherits from `LlamaAttention`, as the weights of the module stay
+ untouched. The only required change is in the forward pass, which needs to correctly call the public API of
+ flash attention and deal with padding tokens if the input contains any.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle the difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ forward = LlamaFlashAttention2.forward
+
+
+class QLlamaSdpaAttention(QLlamaAttention):
+ """
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `LlamaAttention`, as the weights of the module stay untouched. The only changes are in the forward pass, to adapt
+ to the SDPA API.
+ """
+
+ forward = LlamaSdpaAttention.forward
+
+
+QLLAMA_ATTENTION_CLASSES = {
+ "eager": QLlamaAttention,
+ "flash_attention_2": QLlamaFlashAttention2,
+ "sdpa": QLlamaSdpaAttention,
+}
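+
+# Dispatch note: QLlamaDecoderLayer below selects its attention class from this mapping via
+# `config._attn_implementation`. For example, loading the model with
+# `AutoModelForCausalLM.from_pretrained(..., attn_implementation="flash_attention_2")`
+# (standard transformers behavior) would make every layer use QLlamaFlashAttention2.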
+
+
+class QLlamaDecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: QLlamaConfig, layer_idx):
+ super().__init__(config, layer_idx=layer_idx)
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = QLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+ self.mlp = QLlamaMLP(config, layer_idx)
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.layer_idx = layer_idx
+
+ forward = LlamaDecoderLayer.forward
+
+
+class QLlamaPreTrainedModel(LlamaPreTrainedModel):
+ config_class = QLlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["QLlamaDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear) or isinstance(module, QLinearTE):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class QLlamaModel(QLlamaPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+ Args:
+ config: QLlamaConfig
+ """
+
+ def __init__(self, config: QLlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [QLlamaDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ _update_causal_mask = LlamaModel._update_causal_mask
+ forward = LlamaModel.forward
+
+
+class QLlamaForCausalLM(QLlamaPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = QLlamaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.forward_step_id = 0
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ forward = LlamaForCausalLM.forward
+ prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation
+
+
+class QLlamaForSequenceClassification(QLlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = QLlamaModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ forward = LlamaForSequenceClassification.forward
+
+
+AutoConfig.register("qllama", QLlamaConfig)
+AutoModel.register(QLlamaConfig, QLlamaModel)
+AutoModelForCausalLM.register(QLlamaConfig, QLlamaForCausalLM)
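+
+# Minimal construction sketch (illustrative; assumes the quantization arguments expected by
+# QLinearTE are present on the config object):
+#
+#   from transformers import AutoConfig, AutoModelForCausalLM
+#   cfg = AutoConfig.for_model("qllama", num_hidden_layers=2, hidden_size=256,
+#                              intermediate_size=512, num_attention_heads=4)
+#   model = AutoModelForCausalLM.from_config(cfg)   # builds a QLlamaForCausalLM
+#
+# The registrations above route the "qllama" model_type to QLlamaConfig, QLlamaModel and
+# QLlamaForCausalLM through the standard Auto* factories.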
diff --git a/llava/model/language_model/qllava_qllama.py b/llava/model/language_model/qllava_qllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fe4f0785352cc1d9fef4b20afdfb6183672fc32
--- /dev/null
+++ b/llava/model/language_model/qllava_qllama.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file is modified from https://github.com/haotian-liu/LLaVA/
+
+
+import inspect
+import os
+import os.path as osp
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+from transformers import (
+ AutoConfig,
+ AutoModel,
+ GenerationConfig,
+ LlamaConfig,
+ LlamaForCausalLM,
+ PretrainedConfig,
+ PreTrainedModel,
+)
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_utils import ContextManagers, no_init_weights
+
+from ..configuration_llava import LlavaConfig
+from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
+from ..utils import get_model_config, get_model_config_fp8
+from .builder import build_llm_and_tokenizer
+from .llava_llama import LlavaLlamaConfig, LlavaLlamaModel
+
+quantize_args_to_model_class = {
+ "fp8Linear_llama": "QLlamaForCausalLM",
+ "fp8LinearAndActivation_llama": "QMemLlamaForCausalLM",
+ "fp8Linear_qwen2": "FP8LinearQwen2ForCausalLM",
+ "fp8Activation_qwen2": "FP8ActivationQwen2ForCausalLM",
+ "fp8ActivationResidual_qwen2": "FP8ActivationResidualQwen2ForCausalLM",
+}
+
+
+class QLlavaLlamaConfig(LlavaLlamaConfig):
+ model_type = "qllava_qllama"
+
+
+## FIXME: we will follow the usual convention and add a separate CausalLM class in the future.
+class QLlavaLlamaModel(LlavaLlamaModel):
+ config_class = QLlavaLlamaConfig
+ main_input_name = "input_embeds"
+ supports_gradient_checkpointing = True
+
+ def __init__(self, config: QLlavaLlamaConfig = None, model_args=None, *args, **kwargs) -> None:
+ PreTrainedModel.__init__(self, config)
+ self.init_vlm(config=config, model_args=model_args, *args, **kwargs)
+
+ # rewrite to support QLlama
+ def init_vlm(self, config: PreTrainedModel = None, model_args=None, *args, **kwargs):
+ # TODO(ligeng): figure out how from_config and from_pretrained work in the HF implementation.
+ if hasattr(self, "llm") or hasattr(self, "vision_tower") or hasattr(self, "mm_projector"):
+ # already initialized, skipped
+ return
+
+ model_dtype = getattr(config, "model_dtype", "torch.float16")
+ if not hasattr(config, "model_dtype"):
+ warnings.warn("model_dtype not found in config, defaulting to torch.float16.")
+ config.model_dtype = model_dtype
+
+ if model_args.quantize_model in ["fp8Activation_qwen2", "fp8ActivationResidual_qwen2"]:
+ cfgs = get_model_config_fp8(config) # The first cfg is fp8
+ else:
+ cfgs = get_model_config(config)
+ if len(cfgs) == 3:
+ llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs
+ elif len(cfgs) == 4:
+ llm_cfg, vision_tower_cfg, mm_projector_cfg, fp8_llm_cfg = cfgs
+ kwargs.update({"fp8_llm_cfg": fp8_llm_cfg})
+ else:
+ raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.")
+
+ kwargs.update(
+ {
+ "quantize_model_class": quantize_args_to_model_class[model_args.quantize_model],
+ "model_args": model_args,
+ }
+ )
+ self.llm, self.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
+
+ for name, module in self.llm.named_modules():
+ module.layer_name = name
+
+ self.pad_to_multiple_of = model_args.pad_to_multiple_of
+
+ self.post_config()
+ self.is_loaded = True
+
+ assert (
+ self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
+ ), "At least one of the components must be instantiated."
+
+
+AutoConfig.register("qllava_qllama", QLlavaLlamaConfig)
+AutoModel.register(QLlavaLlamaConfig, QLlavaLlamaModel)
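+
+# Selection note: `model_args.quantize_model` picks the quantized LLM class through
+# `quantize_args_to_model_class` above (e.g. "fp8Linear_llama" -> "QLlamaForCausalLM").
+# init_vlm() forwards the resolved name to build_llm_and_tokenizer() as `quantize_model_class`,
+# together with `model_args`, so the LLM builder can construct the corresponding FP8 variant of
+# the backbone (and, for the fp8Activation* settings, it also receives the extra `fp8_llm_cfg`
+# extracted by get_model_config_fp8()).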
diff --git a/llava/model/language_model/qmemllama.py b/llava/model/language_model/qmemllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..26e794840df2230013a66f643745e9f50ffbeef9
--- /dev/null
+++ b/llava/model/language_model/qmemllama.py
@@ -0,0 +1,679 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch LLaMA model."""
+import math
+import os
+import time
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from flash_attn import flash_attn_func, flash_attn_varlen_func
+from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_flash_attention_utils import _flash_attention_forward
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaDynamicNTKScalingRotaryEmbedding,
+ LlamaFlashAttention2,
+ LlamaForCausalLM,
+ LlamaForSequenceClassification,
+ LlamaLinearScalingRotaryEmbedding,
+ LlamaMLP,
+ LlamaModel,
+ LlamaPreTrainedModel,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+ LlamaSdpaAttention,
+ apply_rotary_pos_emb,
+ repeat_kv,
+ rotate_half,
+)
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+
+from ..qlinear_te import QLinearTE
+
+try:
+ import transformer_engine.pytorch as te
+except ImportError:
+ pass  # transformer_engine is optional; it is only needed by the TE-backed quantized layers.
+
+from ..quantization import QGELU, QAct_FPin, QAct_FPout, QAdd, QIdentity, QLayerNorm, QLinear, QMul
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "QMemLlamaConfig"
+
+
+class QMemLlamaConfig(LlamaConfig):
+ model_type = "qmemllama"
+
+
+class QLlamaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6, args=None, layer_type=None):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+ self.qargs = args
+ self.QAct_layernorm_in = QAct_FPout(args, layer_type=layer_type + "_in")
+ self.QAct_layernorm_out = QAct_FPin(args, layer_type=layer_type + "_out")
+
+ def forward(self, hidden_states, s):
+ hidden_states = self.QAct_layernorm_in(hidden_states, s)
+
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ hidden_states = self.weight * hidden_states.to(input_dtype)
+
+ hidden_states, s = self.QAct_layernorm_out(hidden_states)
+ return hidden_states, s
+
+
+ALL_LAYERNORM_LAYERS.append(QLlamaRMSNorm)
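+
+# Calling-convention note (as used throughout this file): unlike LlamaRMSNorm, QLlamaRMSNorm
+# threads an activation scale `s` alongside the hidden states,
+#
+#   hidden_states, s = self.input_layernorm(hidden_states, s)
+#
+# where QAct_FPout consumes a (tensor, scale) pair and returns a plain tensor, and QAct_FPin
+# does the reverse, emitting a fresh (tensor, scale) pair for the next quantized op.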
+
+
+class QMemLlamaMLP(LlamaMLP):
+ def __init__(self, config, layer_idx):
+ super().__init__(config)
+ self.layer_idx = layer_idx
+
+ self.gate_proj = QLinear(
+ self.hidden_size, self.intermediate_size, bias=False, args=config, layer_type="mlp_gate"
+ )
+ self.up_proj = QLinear(self.hidden_size, self.intermediate_size, bias=False, args=config, layer_type="mlp_up")
+ self.down_proj = QLinear(
+ self.intermediate_size, self.hidden_size, bias=False, args=config, layer_type="mlp_down"
+ )
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ self.QAct_act_sum = QAct_FPout(config, layer_type="mlp_act_sum")
+ self.QAct_act_gate = QAct_FPin(config, layer_type="mlp_act_gate")
+
+ self.QAct_act_up = QAct_FPin(config, layer_type="mlp_act_up")
+
+ self.QAct_act_in = QAct_FPout(config, layer_type="mlp_act_in")
+ self.QAct_act_out = QAct_FPin(config, layer_type="mlp_act_out")
+
+ self.QMul_act = QMul(config, layer_type="mul_act")
+
+ def forward(self, x, s):
+ if self.config.pretraining_tp > 1:
+ raise ValueError("Currently Quantization is not implemented for tensor parallel for simplicity")
+ slice = self.intermediate_size // self.config.pretraining_tp
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
+
+ gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
+
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
+ down_proj = [
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
+ ]
+ down_proj = sum(down_proj)
+ else:
+ # down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ x = self.QAct_act_sum(x, s)
+ x_gate, s_gate = self.QAct_act_gate(x)
+ x_up, s_up = self.QAct_act_up(x)
+ x_gate, s_gate = self.gate_proj(x_gate, s_gate)
+ x_gate = self.QAct_act_in(x_gate, s_gate)
+ x_gate = self.act_fn(x_gate)
+ x_gate, s_gate = self.QAct_act_out(x_gate)
+
+ x_up, s_up = self.up_proj(x_up, s_up)
+ x, s = self.QMul_act(x_gate, x_up, s_gate, s_up)
+ down_proj, s = self.down_proj(x, s)
+ return down_proj, s
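+
+# Data-flow note for the quantized branch above: the incoming (x, s) pair is materialized by
+# QAct_act_sum, re-quantized separately for the gate and up paths, pushed through the QLinear
+# projections, recombined by QMul_act after the activation, and returned as a (down_proj, s)
+# pair so the decoder layer can keep propagating the scale.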
+
+
+class QMemLlamaAttention(LlamaAttention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: QMemLlamaConfig, layer_idx):
+ super().__init__(config)
+ self.layer_idx = layer_idx
+
+ self.q_proj = QLinear(
+ self.hidden_size,
+ self.num_heads * self.head_dim,
+ bias=config.attention_bias,
+ args=config,
+ layer_type="attn_q",
+ )
+ self.k_proj = QLinear(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=config.attention_bias,
+ args=config,
+ layer_type="attn_k",
+ )
+ self.v_proj = QLinear(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=config.attention_bias,
+ args=config,
+ layer_type="attn_v",
+ )
+ self.o_proj = QLinear(
+ self.num_heads * self.head_dim,
+ self.hidden_size,
+ bias=config.attention_bias,
+ args=config,
+ layer_type="attn_proj",
+ )
+
+ self.QAct_qkv_sum = QAct_FPout(config, layer_type="attn_qkv_sum")
+
+ self.QAct_q_in = QAct_FPin(config, layer_type="attn_q_in")
+ self.QAct_k_in = QAct_FPin(config, layer_type="attn_k_in")
+ self.QAct_v_in = QAct_FPin(config, layer_type="attn_v_in")
+
+ self.QAct_q_out = QAct_FPout(config, layer_type="attn_q_out")
+ self.QAct_k_out = QAct_FPout(config, layer_type="attn_k_out")
+ self.QAct_v_out = QAct_FPout(config, layer_type="attn_v_out")
+ self.QAct_proj_in = QAct_FPin(config, layer_type="attn_proj_in")
+
+
+class QMemLlamaFlashAttention2(QMemLlamaAttention):
+ """
+ Llama flash attention module. This module inherits from `LlamaAttention`, as the weights of the module stay
+ untouched. The only required change is in the forward pass, which needs to correctly call the public API of
+ flash attention and deal with padding tokens if the input contains any.
+ """
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ s: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ hidden_states = self.QAct_qkv_sum(hidden_states, s)
+
+ q, sq = self.QAct_q_in(hidden_states)
+ k, sk = self.QAct_k_in(hidden_states)
+ v, sv = self.QAct_v_in(hidden_states)
+
+ query_states, sq = self.q_proj(q, sq)
+ key_states, sk = self.k_proj(k, sk)
+ value_states, sv = self.v_proj(v, sv)
+
+ query_states = self.QAct_q_out(query_states, sq)
+ key_states = self.QAct_k_out(key_states, sk)
+ value_states = self.QAct_v_out(value_states, sv)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim,
+ # so we split the hidden dimension into heads here (and transpose only for RoPE/cache handling).
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
+ # so the input hidden states may get silently cast to float32. We therefore need to
+ # cast them back to the correct dtype to make sure everything works as expected.
+ # This can slow down training and inference, so it is recommended not to cast the LayerNorms
+ # to fp32. (LlamaRMSNorm handles it correctly.)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+
+ attn_output = attn_output.to(torch.float32)
+ attn_output, s = self.QAct_proj_in(attn_output)
+ attn_output, s = self.o_proj(attn_output, s)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, s, attn_weights, past_key_value
+
+
+class QMemLlamaSdpaAttention(QMemLlamaAttention):
+ """
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `LlamaAttention`, as the weights of the module stay untouched. The only changes are in the forward pass, to adapt
+ to the SDPA API.
+ """
+
+ # Adapted from LlamaAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ s: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ hidden_states = self.QAct_qkv_sum(hidden_states, s)
+
+ q, sq = self.QAct_q_in(hidden_states)
+ k, sk = self.QAct_k_in(hidden_states)
+ v, sv = self.QAct_v_in(hidden_states)
+
+ query_states, sq = self.q_proj(q, sq)
+ key_states, sk = self.k_proj(k, sk)
+ value_states, sv = self.v_proj(v, sv)
+
+ query_states = self.QAct_q_out(query_states, sq)
+ key_states = self.QAct_k_out(key_states, sk)
+ value_states = self.QAct_v_out(value_states, sv)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ # attn_output = attn_output.to(torch.float32)
+ attn_output, s = self.QAct_proj_in(attn_output)
+ attn_output, s = self.o_proj(attn_output, s)
+
+ return attn_output, s, None, past_key_value
+
+
+QMemLLAMA_ATTENTION_CLASSES = {
+ "eager": QMemLlamaAttention,
+ "flash_attention_2": QMemLlamaFlashAttention2,
+ "sdpa": QMemLlamaSdpaAttention,
+}
+
+
+class QMemLlamaDecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: QMemLlamaConfig, layer_idx):
+ super().__init__(config, layer_idx=layer_idx)
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = QMemLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+ self.mlp = QMemLlamaMLP(config, layer_idx)
+
+ self.input_layernorm = QLlamaRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, args=config, layer_type="ln_attn"
+ )
+ self.post_attention_layernorm = QLlamaRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, args=config, layer_type="ln_mlp"
+ )
+
+ self.QAdd_attn = QAdd(config, layer_type="add_attn")
+ self.QAdd_mlp = QAdd(config, layer_type="add_mlp")
+
+ self.QAct_reattnout_fx = QAct_FPin(config, layer_type="re_attn_out_fx")
+ self.QAct_reattnout_re = QAct_FPin(config, layer_type="re_attn_out_re")
+
+ self.QAct_remlpout_fx = QAct_FPin(config, layer_type="re_mlp_out_fx")
+ self.QAct_remlpout_re = QAct_FPin(config, layer_type="re_mlp_out_re")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+ into the model
+ """
+ residual, res = self.QAct_reattnout_re(hidden_states)
+ hidden_states, s = self.QAct_reattnout_fx(hidden_states)
+
+ hidden_states, s = self.input_layernorm(hidden_states, s)
+
+ # Self Attention
+ hidden_states, s, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ s=s,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = self.QAdd_attn(residual, hidden_states, res, s)
+
+ # Fully Connected
+ residual, res = self.QAct_remlpout_re(hidden_states)
+ hidden_states, s = self.QAct_remlpout_fx(hidden_states)
+
+ hidden_states, s = self.post_attention_layernorm(hidden_states, s)
+ hidden_states, s = self.mlp(hidden_states, s)
+ hidden_states = self.QAdd_mlp(residual, hidden_states, res, s)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
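+
+# Residual-path note: each sub-block keeps two quantized views of the same activation, a
+# residual copy (QAct_re*_re) and a forward copy (QAct_re*_fx). The forward copy runs through
+# the norm / attention / MLP stack while the residual copy is held back, and QAdd_attn /
+# QAdd_mlp re-add the two using their respective scales.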
+
+
+class QMemLlamaPreTrainedModel(LlamaPreTrainedModel):
+ config_class = QMemLlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["QMemLlamaDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear) or isinstance(module, QLinearTE):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class QMemLlamaModel(QMemLlamaPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+ Args:
+ config: QMemLlamaConfig
+ """
+
+ def __init__(self, config: QMemLlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [QMemLlamaDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ _update_causal_mask = LlamaModel._update_causal_mask
+ forward = LlamaModel.forward
+
+
+class QMemLlamaForCausalLM(QMemLlamaPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = QMemLlamaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.forward_step_id = 0
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ forward = LlamaForCausalLM.forward
+ prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation
+
+
+class QMemLlamaForSequenceClassification(QMemLlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = QMemLlamaModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ forward = LlamaForSequenceClassification.forward
+
+
+AutoConfig.register("qmemllama", QMemLlamaConfig)
+AutoModel.register(QMemLlamaConfig, QMemLlamaModel)
+AutoModelForCausalLM.register(QMemLlamaConfig, QMemLlamaForCausalLM)
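+
+# Generation sketch (illustrative): QMemLlamaForCausalLM reuses LlamaForCausalLM.forward and
+# prepare_inputs_for_generation unchanged, so once a model instance and tokenizer are in hand
+# it should be drivable by the standard HF generation loop, e.g.
+#
+#   input_ids = tokenizer("Hello", return_tensors="pt").input_ids
+#   out_ids = model.generate(input_ids, max_new_tokens=32)
+#
+# subject to the quantization fields that the QAct/QLinear layers expect being present on the
+# config.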
diff --git a/llava/model/language_model/realqmemllama.py b/llava/model/language_model/realqmemllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..9eb71d015f8f7a0dcffc466b4093f67970537335
--- /dev/null
+++ b/llava/model/language_model/realqmemllama.py
@@ -0,0 +1,674 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch LLaMA model."""
+import math
+import os
+import time
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from flash_attn import flash_attn_func, flash_attn_varlen_func
+from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_flash_attention_utils import _flash_attention_forward
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaDynamicNTKScalingRotaryEmbedding,
+ LlamaFlashAttention2,
+ LlamaForCausalLM,
+ LlamaForSequenceClassification,
+ LlamaLinearScalingRotaryEmbedding,
+ LlamaMLP,
+ LlamaModel,
+ LlamaPreTrainedModel,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+ LlamaSdpaAttention,
+ apply_rotary_pos_emb,
+ repeat_kv,
+ rotate_half,
+)
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+
+from ..qlinear_te import QLinearTE
+
+try:
+ import transformer_engine.pytorch as te
+except ImportError:
+ pass  # transformer_engine is optional; it is only needed by the TE-backed quantized layers.
+from ..quantization import QGELU, QAct_FPin, QAct_FPout, QAdd, QIdentity, QLayerNorm, QLinear, QMul
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "QMemLlamaConfig"
+
+
+class QMemLlamaConfig(LlamaConfig):
+ model_type = "qmemllama"
+
+
+class QLlamaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6, args=None, layer_type=None):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+ self.qargs = args
+ self.QAct_layernorm_in = QAct_FPout(args, layer_type=layer_type + "_in")
+ self.QAct_layernorm_out = QAct_FPin(args, layer_type=layer_type + "_out")
+
+ def forward(self, hidden_states, s):
+ hidden_states = self.QAct_layernorm_in(hidden_states, s)
+
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ hidden_states = self.weight * hidden_states.to(input_dtype)
+
+ hidden_states, s = self.QAct_layernorm_out(hidden_states)
+ return hidden_states, s
+
+
+ALL_LAYERNORM_LAYERS.append(QLlamaRMSNorm)
+
+
+class QMemLlamaMLP(LlamaMLP):
+ def __init__(self, config, layer_idx):
+ super().__init__(config)
+ self.layer_idx = layer_idx
+ self.gate_proj = QLinear(
+ self.hidden_size, self.intermediate_size, bias=False, args=config, layer_type="mlp_gate"
+ )
+ self.up_proj = QLinear(self.hidden_size, self.intermediate_size, bias=False, args=config, layer_type="mlp_up")
+ self.down_proj = QLinear(
+ self.intermediate_size, self.hidden_size, bias=False, args=config, layer_type="mlp_down"
+ )
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ self.QAct_act_sum = QAct_FPout(config, layer_type="mlp_act_sum")
+ self.QAct_act_gate = QAct_FPin(config, layer_type="mlp_act_gate")
+ self.QAct_act_up = QAct_FPin(config, layer_type="mlp_act_up")
+
+ self.QAct_act_in = QAct_FPout(config, layer_type="mlp_act_in")
+ self.QAct_act_out = QAct_FPin(config, layer_type="mlp_act_out")
+
+ self.QMul_act = QMul(config, layer_type="mul_act")
+
+ def forward(self, x, s):
+ if self.config.pretraining_tp > 1:
+ raise ValueError("Currently Quantization is not implemented for tensor parallel for simplicity")
+ slice = self.intermediate_size // self.config.pretraining_tp
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
+
+ gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
+
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
+ down_proj = [
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
+ ]
+ down_proj = sum(down_proj)
+ else:
+ # down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ x = self.QAct_act_sum(x, s)
+ x_gate, s_gate = self.QAct_act_gate(x)
+ x_up, s_up = self.QAct_act_up(x)
+ x_gate, s_gate = self.gate_proj(x_gate, s_gate)
+ x_gate = self.QAct_act_in(x_gate, s_gate)
+ x_gate = self.act_fn(x_gate)
+ x_gate, s_gate = self.QAct_act_out(x_gate)
+
+ x_up, s_up = self.up_proj(x_up, s_up)
+ x, s = self.QMul_act(x_gate, x_up, s_gate, s_up)
+ down_proj, s = self.down_proj(x, s)
+ return down_proj, s
+
+
+class QMemLlamaAttention(LlamaAttention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: QMemLlamaConfig, layer_idx):
+ super().__init__(config)
+ self.layer_idx = layer_idx
+ self.q_proj = QLinear(
+ self.hidden_size,
+ self.num_heads * self.head_dim,
+ bias=config.attention_bias,
+ args=config,
+ layer_type="attn_q",
+ )
+ self.k_proj = QLinear(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=config.attention_bias,
+ args=config,
+ layer_type="attn_k",
+ )
+ self.v_proj = QLinear(
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ bias=config.attention_bias,
+ args=config,
+ layer_type="attn_v",
+ )
+ self.o_proj = QLinear(
+ self.num_heads * self.head_dim,
+ self.hidden_size,
+ bias=config.attention_bias,
+ args=config,
+ layer_type="attn_proj",
+ )
+
+ self.QAct_qkv_sum = QAct_FPout(config, layer_type="attn_qkv_sum")
+
+ self.QAct_q_in = QAct_FPin(config, layer_type="attn_q_in")
+ self.QAct_k_in = QAct_FPin(config, layer_type="attn_k_in")
+ self.QAct_v_in = QAct_FPin(config, layer_type="attn_v_in")
+
+ self.QAct_q_out = QAct_FPout(config, layer_type="attn_q_out")
+ self.QAct_k_out = QAct_FPout(config, layer_type="attn_k_out")
+ self.QAct_v_out = QAct_FPout(config, layer_type="attn_v_out")
+ self.QAct_proj_in = QAct_FPin(config, layer_type="attn_proj_in")
+
+
+class QMemLlamaFlashAttention2(QMemLlamaAttention):
+ """
+ Llama flash attention module. This module inherits from `LlamaAttention`, as the weights of the module stay
+ untouched. The only required change is in the forward pass, which needs to correctly call the public API of
+ flash attention and deal with padding tokens if the input contains any.
+ """
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ s: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ hidden_states = self.QAct_qkv_sum(hidden_states, s)
+
+ q, sq = self.QAct_q_in(hidden_states)
+ k, sk = self.QAct_k_in(hidden_states)
+ v, sv = self.QAct_v_in(hidden_states)
+
+ query_states, sq = self.q_proj(q, sq)
+ key_states, sk = self.k_proj(k, sk)
+ value_states, sv = self.v_proj(v, sv)
+
+ query_states = self.QAct_q_out(query_states, sq)
+ key_states = self.QAct_k_out(key_states, sk)
+ value_states = self.QAct_v_out(value_states, sv)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim,
+ # so we split the hidden dimension into heads here (and transpose only for RoPE/cache handling).
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
+ # so the input hidden states may get silently cast to float32. We therefore need to
+ # cast them back to the correct dtype to make sure everything works as expected.
+ # This can slow down training and inference, so it is recommended not to cast the LayerNorms
+ # to fp32. (LlamaRMSNorm handles it correctly.)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+
+ attn_output = attn_output.to(torch.float32)
+ attn_output, s = self.QAct_proj_in(attn_output)
+ attn_output, s = self.o_proj(attn_output, s)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, s, attn_weights, past_key_value
+
+
+class QMemLlamaSdpaAttention(QMemLlamaAttention):
+ """
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `LlamaAttention`, as the weights of the module stay untouched. The only changes are in the forward pass, to adapt
+ to the SDPA API.
+ """
+
+ # Adapted from LlamaAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ s: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ hidden_states = self.QAct_qkv_sum(hidden_states, s)
+
+ q, sq = self.QAct_q_in(hidden_states)
+ k, sk = self.QAct_k_in(hidden_states)
+ v, sv = self.QAct_v_in(hidden_states)
+
+ query_states, sq = self.q_proj(q, sq)
+ key_states, sk = self.k_proj(k, sk)
+ value_states, sv = self.v_proj(v, sv)
+
+ query_states = self.QAct_q_out(query_states, sq)
+ key_states = self.QAct_k_out(key_states, sk)
+ value_states = self.QAct_v_out(value_states, sv)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ # attn_output = attn_output.to(torch.float32)
+ attn_output, s = self.QAct_proj_in(attn_output)
+ attn_output, s = self.o_proj(attn_output, s)
+
+ return attn_output, s, None, past_key_value
+
+
+QMemLLAMA_ATTENTION_CLASSES = {
+ "eager": QMemLlamaAttention,
+ "flash_attention_2": QMemLlamaFlashAttention2,
+ "sdpa": QMemLlamaSdpaAttention,
+}
+
+
+class QMemLlamaDecoderLayer(LlamaDecoderLayer):
+ def __init__(self, config: QMemLlamaConfig, layer_idx):
+ super().__init__(config, layer_idx=layer_idx)
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = QMemLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+ self.mlp = QMemLlamaMLP(config, layer_idx)
+ self.input_layernorm = QLlamaRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, args=config, layer_type="ln_attn"
+ )
+ self.post_attention_layernorm = QLlamaRMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, args=config, layer_type="ln_mlp"
+ )
+
+ self.QAdd_attn = QAdd(config, layer_type="add_attn")
+ self.QAdd_mlp = QAdd(config, layer_type="add_mlp")
+
+ self.QAct_reattnout_fx = QAct_FPin(config, layer_type="re_attn_out_fx")
+ self.QAct_reattnout_re = QAct_FPin(config, layer_type="re_attn_out_re")
+
+ self.QAct_remlpout_fx = QAct_FPin(config, layer_type="re_mlp_out_fx")
+ self.QAct_remlpout_re = QAct_FPin(config, layer_type="re_mlp_out_re")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+ residual, res = self.QAct_reattnout_re(hidden_states)
+ hidden_states, s = self.QAct_reattnout_fx(hidden_states)
+
+ hidden_states, s = self.input_layernorm(hidden_states, s)
+
+ # Self Attention
+ hidden_states, s, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ s=s,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = self.QAdd_attn(residual, hidden_states, res, s)
+
+ # Fully Connected
+ residual, res = self.QAct_remlpout_re(hidden_states)
+ hidden_states, s = self.QAct_remlpout_fx(hidden_states)
+
+ hidden_states, s = self.post_attention_layernorm(hidden_states, s)
+ hidden_states, s = self.mlp(hidden_states, s)
+ hidden_states = self.QAdd_mlp(residual, hidden_states, res, s)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class QMemLlamaPreTrainedModel(LlamaPreTrainedModel):
+ config_class = QMemLlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["QMemLlamaDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear) or isinstance(module, QLinearTE):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class QMemLlamaModel(QMemLlamaPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+ Args:
+ config: QMemLlamaConfig
+ """
+
+ def __init__(self, config: QMemLlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [QMemLlamaDecoderLayer(config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
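+    # The causal-mask construction and the forward pass are reused unchanged from the upstream
+    # `LlamaModel`; only the decoder layers (and their quantized submodules) differ.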
+ _update_causal_mask = LlamaModel._update_causal_mask
+ forward = LlamaModel.forward
+
+
+class QMemLlamaForCausalLM(QMemLlamaPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = QMemLlamaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.forward_step_id = 0
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
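+    # The forward pass and generation-input preparation are reused unchanged from the upstream
+    # `LlamaForCausalLM`.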
+ forward = LlamaForCausalLM.forward
+ prepare_inputs_for_generation = LlamaForCausalLM.prepare_inputs_for_generation
+
+
+class QMemLlamaForSequenceClassification(QMemLlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = QMemLlamaModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ forward = LlamaForSequenceClassification.forward
+
+
+AutoConfig.register("qmemllama", QMemLlamaConfig)
+AutoModel.register(QMemLlamaConfig, QMemLlamaModel)
+AutoModelForCausalLM.register(QMemLlamaConfig, QMemLlamaForCausalLM)
diff --git a/llava/model/liger/cross_entropy.py b/llava/model/liger/cross_entropy.py
new file mode 100644
index 0000000000000000000000000000000000000000..06274f332b62ac2178904439a18c3474693242a2
--- /dev/null
+++ b/llava/model/liger/cross_entropy.py
@@ -0,0 +1,417 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+import operator
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from .utils import compare_version, element_mul_kernel, is_hip
+
+if compare_version("triton", operator.ge, "3.0.0"):
+ try:
+ # typical import path with dispatch available
+ from triton.language.extra.libdevice import tanh
+ except ModuleNotFoundError:
+ # for working with NGC containers
+ from triton.language.extra.cuda.libdevice import tanh
+else:
+ from triton.language.math import tanh
+
+_TRUE = tl.constexpr(1)
+_FALSE = tl.constexpr(0)
+
+
+@triton.jit
+def liger_cross_entropy_kernel(
+ X_ptr,
+ X_stride,
+ Y_ptr,
+ Y_stride,
+ loss_ptr,
+ z_loss_ptr,
+ loss_stride,
+ n_cols,
+ n_non_ignore,
+ ignore_index,
+ lse_square_scale: tl.constexpr,
+ label_smoothing: tl.constexpr,
+ reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
+ softcap,
+ RETURN_Z_LOSS: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+ HAS_SOFTCAPPING: tl.constexpr,
+):
+ """
+ This kernel computes both cross entropy loss and the gradient of the input.
+ We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.
+
+ Parameters:
+ X_ptr: Pointer to input tensor.
+ X_stride (int): The stride of the input tensor.
+ Y_ptr: Pointer to target tensor.
+ Y_stride (int): The stride of the target tensor.
+ loss_ptr: Pointer to tensor to store the loss.
+ z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
+ loss_stride (int): The stride of the loss tensor.
+ n_cols (int): The number of columns in the input tensor.
+ n_non_ignore (int): The number of non-ignored elements in the batch.
+ ignore_index (int): The index to ignore in the target.
+        lse_square_scale (float): The scale of (logsumexp(_input)) ^ 2 added to the loss for training stability.
+        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+        reduction (str): The reduction to apply ("mean" or "sum").
+        softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
+        RETURN_Z_LOSS (int): Whether to store the z loss to z_loss_ptr (1) or not (0).
+        BLOCK_SIZE (int): The block size for Triton operations.
+        HAS_SOFTCAPPING (bool): Whether soft-capping is applied to the logits.
+ """
+
+ # https://github.com/triton-lang/triton/issues/1058
+ # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64
+ program_id = tl.program_id(0).to(tl.int64)
+
+ # 1. Load Y_ptr first because if the target is ignore_index, we can return right away
+ Y_ptr += program_id * Y_stride
+ y = tl.load(Y_ptr)
+
+ # 2. locate the start index
+ X_ptr += program_id * X_stride
+
+ if y == ignore_index:
+ # set all X_ptr as 0
+ for i in range(0, n_cols, BLOCK_SIZE):
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
+ tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+ return
+
+ loss_ptr += program_id * loss_stride
+ z_loss_ptr += program_id * loss_stride
+
+ # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
+ # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
+
+ # 3. [Online softmax] first pass: find max + sum
+ m = float("-inf") # m is the max value. use the notation from the paper
+ d = 0.0 # d is the sum. use the notation from the paper
+ ori_X_y = tl.load(X_ptr + y) # we need to store the original value of X_y for the loss calculation
+ if HAS_SOFTCAPPING:
+ ori_X_y = softcap * tanh(ori_X_y / softcap)
+
+ # Label smoothing is a general case of normal cross entropy
+ # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
+ scaled_x_sum = 0.0
+ eps = label_smoothing / n_cols
+
+ for i in range(0, n_cols, BLOCK_SIZE):
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
+ X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf"))
+ if HAS_SOFTCAPPING:
+ X_block = softcap * tanh(X_block / softcap)
+ block_max = tl.max(X_block)
+ if label_smoothing > 0:
+ # scale X beforehand to avoid overflow
+ scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
+ m_new = tl.maximum(m, block_max)
+ d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
+ m = m_new
+
+ # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X)))))
+ # = log (e^(max(X)) * sum(e ^ (X_i - max(X))))
+ # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
+ lse = m + tl.log(d)
+
+ # 4. [Online Softmax] Second pass: compute gradients
+ # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
+ # dx_y = (softmax(x_y) - 1) / N
+ # dx_i = softmax(x_i) / N, i != y
+ # For label smoothing:
+ # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
+ # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
+ # = dx_i - (1 - label_smoothing) / N
+ # With Z loss:
+ # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y
+ # dx_y = dx_i - (1 - label_smoothing) / N
+ # For 'sum' reduction, no normalization is applied:
+ # dx_y = softmax(x_y) - 1
+ # dx_i = softmax(x_i), for i ≠ y
+
+ for i in range(0, n_cols, BLOCK_SIZE):
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
+ X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf"))
+ if HAS_SOFTCAPPING:
+ intermediate = tanh(X_block / softcap)
+ X_block = softcap * intermediate
+ # softmax(x_i)
+ X_block = tl.exp(X_block - m) / d
+ # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+ X_block += 2 * lse_square_scale * lse * X_block
+ # smoothing term
+ X_block += -eps
+ # special handle dx_y
+ X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+ # reduction scale
+ if reduction == "mean":
+ X_block = X_block / (n_non_ignore)
+ # chain rule
+ # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+ if HAS_SOFTCAPPING:
+ X_block = X_block * (1 - intermediate * intermediate)
+
+ tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+
+ # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
+ # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
+ tl.debug_barrier()
+
+ # 5. Calculate the loss
+
+ # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
+ # = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
+ # = X_y - m - log d = X_y - lse
+    # sum(e ^ (X - max(X))) must be >= 1 because the max term is e ^ 0 = 1
+ # So we can safely calculate log (softmax(X_y)) without overflow
+ loss = lse - ori_X_y
+
+ # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
+ # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
+ # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
+ # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
+ # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
+ # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
+ # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
+ # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
+ if label_smoothing > 0:
+ smooth_loss = scaled_x_sum + label_smoothing * lse
+ loss = loss * (1 - label_smoothing) + smooth_loss
+
+ # An auxiliary loss, z_loss
+ # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
+ z_loss = lse_square_scale * lse * lse
+ loss += z_loss
+ # Normalize the loss by the number of non-ignored elements if reduction is "mean"
+ if reduction == "mean":
+ z_loss = z_loss / n_non_ignore
+ loss = loss / n_non_ignore
+
+ tl.store(loss_ptr, loss)
+ if RETURN_Z_LOSS == _TRUE:
+ tl.store(z_loss_ptr, z_loss)
+
+
+# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
+# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
+# The optimal maximum block size depends on your hardware, your kernel, and your dtype
+MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning
+
+
+_bool_to_return_z_loss = {
+ True: _TRUE.value,
+ False: _FALSE.value,
+}
+
+
+def cross_entropy_forward(
+ _input,
+ target,
+ ignore_index,
+ lse_square_scale,
+ label_smoothing,
+ reduction,
+ softcap,
+ return_z_loss,
+):
+ if not isinstance(return_z_loss, int):
+ assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
+ return_z_loss = _bool_to_return_z_loss[return_z_loss]
+ else:
+ assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
+
+ BT, V = _input.shape
+ n_rows = BT
+
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+ # unreduced loss
+ loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
+ if return_z_loss == _TRUE.value:
+ z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
+ else:
+ z_loss_1d = loss_1d # dummy ptr when return_z_loss == False
+
+ n_non_ignore = (target != ignore_index).sum().item()
+
+ # ensure _input and target are contiguous in the last dimension
+ if _input.stride(-1) != 1:
+ _input = _input.contiguous()
+ if target.stride(-1) != 1:
+ target = target.contiguous()
+
+ # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
+ liger_cross_entropy_kernel[(n_rows,)](
+ X_ptr=_input,
+ X_stride=_input.stride(-2),
+ Y_ptr=target,
+ Y_stride=target.stride(-1), # always 1
+ loss_ptr=loss_1d,
+ z_loss_ptr=z_loss_1d,
+ loss_stride=loss_1d.stride(-1), # always 1
+ n_cols=V,
+ n_non_ignore=n_non_ignore,
+ ignore_index=ignore_index,
+ lse_square_scale=lse_square_scale,
+ label_smoothing=label_smoothing,
+ reduction=reduction,
+ softcap=softcap if softcap is not None else 0.0,
+ RETURN_Z_LOSS=return_z_loss,
+ BLOCK_SIZE=BLOCK_SIZE,
+ HAS_SOFTCAPPING=True if softcap is not None else False,
+ # TODO: 32 seems to give the best performance
+ # Performance is quite sensitive to num_warps
+ num_warps=32 if not is_hip() else 16,
+ )
+
+ loss = torch.sum(loss_1d)
+ if return_z_loss == _TRUE.value:
+ z_loss = torch.sum(z_loss_1d)
+ else:
+ z_loss = None
+
+ return loss, z_loss, _input
+
+
+def cross_entropy_backward(_input, grad_output):
+ # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
+ if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+ pass
+
+ # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
+ # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
+ else:
+ BT, V = _input.shape
+ n_rows = BT
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+ element_mul_kernel[(n_rows,)](
+ _input,
+ _input.stride(-2),
+ grad_output,
+ V,
+ BLOCK_SIZE=BLOCK_SIZE,
+ num_warps=32 if not is_hip() else 16,
+ )
+
+ return _input
+
+
+class LigerCrossEntropyFunction(torch.autograd.Function):
+ """
+ This class implements a custom autograd function for the Liger Cross Entropy loss.
+ It overrides the forward and backward methods of the torch.autograd.Function class.
+ """
+
+ @staticmethod
+ def forward(
+ ctx,
+ _input: torch.Tensor,
+ target: torch.Tensor,
+ ignore_index: int = -100,
+ lse_square_scale: float = 0.0,
+ label_smoothing: float = 0.0,
+ reduction: str = "mean",
+ softcap: Optional[float] = None,
+ return_z_loss: bool = False,
+ ):
+ """
+ The forward pass of the Liger Cross Entropy loss.
+
+ Parameters:
+ ctx : The context object.
+ _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
+ target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
+ ignore_index (int): The index to ignore in the target.
+            lse_square_scale (float): The scale of (logsumexp(_input)) ^ 2 added to the loss for training stability.
+ label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+            reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
+ softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
+ return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+
+ Returns:
+            tuple: A tuple (loss, z_loss) with the computed losses. The elements are tensors or None.
+ """
+ loss, z_loss, _input = cross_entropy_forward(
+ _input,
+ target,
+ ignore_index,
+ lse_square_scale,
+ label_smoothing,
+ reduction,
+ softcap,
+ return_z_loss,
+ )
+ # TODO: investigation
+ # If we don't detach the _input tensor, the memory will double
+ # Not sure why but seems that there will be a time both grad and value exist but in different location
+ ctx.save_for_backward(_input.detach())
+ ctx.return_z_loss = return_z_loss
+
+ return loss, z_loss
+
+ @staticmethod
+    def backward(ctx, grad_output, grad_output2):
+ """
+ The backward pass of the Liger Cross Entropy loss.
+
+ Parameters:
+ ctx : The context object with saved tensors.
+ grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
+            grad_output2 (tensor): Not used.
+ Returns:
+ tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
+ """
+ if ctx.return_z_loss:
+            del grad_output2  # z_loss is only for logging
+
+ (_input,) = ctx.saved_tensors
+ _input = cross_entropy_backward(_input, grad_output)
+ return (
+ _input,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+def liger_fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
+ reduction = "sum" if num_items_in_batch is not None else "mean"
+ # loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
+ loss, _ = LigerCrossEntropyFunction.apply(source, target, ignore_index, 0.0, 0.0, reduction)
+ if reduction == "sum":
+ loss = loss / num_items_in_batch
+ return loss
+
+
+def LigerForCausalLMLoss(
+ logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
+):
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+
+ # Flatten the tokens
+ shift_logits = shift_logits.view(-1, vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = liger_fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
+ return loss
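+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (assumes a CUDA device, since the fused loss is a Triton kernel).
+    # Shapes follow the docstrings above: logits (B, T, V), labels (B, T); the sizes below are
+    # hypothetical and for illustration only.
+    B, T, V = 2, 8, 128
+    logits = torch.randn(B, T, V, device="cuda", requires_grad=True)
+    labels = torch.randint(0, V, (B, T), device="cuda")
+    loss = LigerForCausalLMLoss(logits, labels, vocab_size=V)
+    loss.backward()
+    print(loss.item(), logits.grad.shape)  # scalar loss; gradient has shape (B, T, V)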
diff --git a/llava/model/liger/utils.py b/llava/model/liger/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5290b98c237f43757d0ed74d91f8f3bf6a807eeb
--- /dev/null
+++ b/llava/model/liger/utils.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+"""
+This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
+See the original Unsloth repository at https://github.com/unslothai/unsloth.
+
+The following line
+https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23
+is based on code from Unsloth, located at:
+https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
+
+Modifications made by Yanning Chen, 2024.
+"""
+
+import functools
+import importlib
+import operator
+from typing import Callable
+
+import torch
+import triton
+import triton.language as tl
+from packaging.version import Version
+
+
+def is_hip() -> bool:
+ return torch.version.hip is not None
+
+
+def ensure_contiguous(fn):
+ @functools.wraps(fn)
+ def wrapper(ctx, *args, **kwargs):
+ def maybe_to_contiguous(x):
+ return x.contiguous() if isinstance(x, torch.Tensor) else x
+
+ args = [maybe_to_contiguous(arg) for arg in args]
+ kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()}
+ return fn(ctx, *args, **kwargs)
+
+ return wrapper
+
+
+def calculate_settings(n):
+ # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
+
+ MAX_FUSED_SIZE = 65536
+ BLOCK_SIZE = triton.next_power_of_2(n)
+ if BLOCK_SIZE > MAX_FUSED_SIZE:
+ raise RuntimeError(
+ f"Cannot launch Triton kernel since n = {n} exceeds "
+ f"the recommended Triton blocksize = {MAX_FUSED_SIZE}."
+ )
+
+ num_warps = 4
+ if BLOCK_SIZE >= 32768:
+ num_warps = 32 if not is_hip() else 16
+ elif BLOCK_SIZE >= 8192:
+ num_warps = 16
+ elif BLOCK_SIZE >= 2048:
+ num_warps = 8
+ return BLOCK_SIZE, num_warps
+
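+# For example, a 32000-entry vocabulary gives calculate_settings(32000) == (32768, 32) on NVIDIA
+# GPUs and (32768, 16) on ROCm/HIP.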
+
+def compare_version(package: str, operator: Callable, target: str):
+ try:
+ pkg = importlib.import_module(package)
+ except ImportError:
+ return False
+ pkg_version = Version(pkg.__version__)
+ return operator(pkg_version, Version(target))
+
+
+def get_amp_custom_fwd_bwd() -> Callable:
+ if compare_version("torch", operator.ge, "2.4.0"):
+ return (
+ functools.partial(torch.amp.custom_fwd, device_type="cuda"),
+ functools.partial(torch.amp.custom_bwd, device_type="cuda"),
+ )
+ return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
+
+
+amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()
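+# These decorators are meant for the `forward` / `backward` of custom `torch.autograd.Function`s so
+# that they interact correctly with autocast (mixed precision); on torch >= 2.4 they map to
+# `torch.amp.custom_fwd` / `torch.amp.custom_bwd` with `device_type="cuda"`.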
+
+
+torch_to_triton_dtype = {
+ torch.float32: tl.float32,
+ torch.float16: tl.float16,
+ torch.bfloat16: tl.bfloat16,
+}
+
+
+@triton.jit
+def element_mul_kernel(
+ X_ptr,
+ X_stride,
+ grad_output_ptr,
+ n_cols,
+ BLOCK_SIZE: tl.constexpr,
+):
+ """
+ This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
+ The multiplication is performed in-place on the tensor pointed by X_ptr.
+
+ Parameters:
+ X_ptr: Pointer to the input tensor.
+ X_stride (int): The stride of the input tensor.
+ grad_output_ptr: Pointer to the gradient output value.
+ n_cols (int): The number of columns in the input tensor.
+ BLOCK_SIZE (int): The block size for Triton operations.
+ """
+
+ # Get the program ID and convert it to int64 to avoid overflow
+ program_id = tl.program_id(0).to(tl.int64)
+
+ # Locate the start index
+ X_ptr += program_id * X_stride
+
+ # Load the gradient output value
+ grad_output = tl.load(grad_output_ptr)
+
+ # Perform the element-wise multiplication
+ for i in range(0, n_cols, BLOCK_SIZE):
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
+ X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
+ tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
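+
+
+if __name__ == "__main__":
+    # Usage sketch (assumes a CUDA device): scale every row of `x` in place by a scalar gradient
+    # value, which is how the cross-entropy backward pass uses this kernel.
+    x = torch.randn(4, 1024, device="cuda")
+    grad = torch.tensor(0.5, device="cuda")
+    BLOCK_SIZE, num_warps = calculate_settings(x.shape[1])
+    element_mul_kernel[(x.shape[0],)](
+        x, x.stride(0), grad, x.shape[1], BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+    )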
diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6030ea9c5b2f29b4a26538bfe7232fdd69e0138
--- /dev/null
+++ b/llava/model/llava_arch.py
@@ -0,0 +1,861 @@
+# Copyright (c) 2025 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
+# LICENSE is in incl_licenses directory.
+
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import json
+import logging
+import os
+import os.path as osp
+import warnings
+from abc import ABC
+from collections import OrderedDict, defaultdict, deque
+from itertools import chain
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from einops import rearrange
+from hydra.utils import instantiate
+from transformers import AutoConfig, GenerationConfig, LogitsProcessor, PreTrainedModel
+from transformers.modeling_utils import ContextManagers, no_init_weights
+
+from llava.constants import DEFAULT_SOUND_TOKEN, DEFAULT_SPEECH_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS
+from llava.mm_utils import process_image, process_images, process_sounds, process_sound_masks
+from llava.model.configuration_llava import LlavaConfig, ResponseFormat
+from llava.model.language_model.builder import build_llm_and_tokenizer
+from llava.model.multimodal_encoder.builder import build_sound_tower
+from llava.model.multimodal_projector.builder import build_speech_mm_projector, build_sound_mm_projector
+from llava.model.utils import get_model_config
+from llava.train.sequence_parallel import get_pg_manager
+from llava.utils import distributed
+from llava.utils.media import extract_media
+from llava.utils.tokenizer import tokenize_conversation
+
+
+class LlavaMetaModel(ABC):
+ def _init_llm(self, llm_cfg, config, *args, **kwargs):
+ llm, tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
+ return llm, tokenizer
+
+ def init_vlm(self, config: PreTrainedModel = None, *args, **kwargs):
+ # TODO(ligeng): figure out how from_config and from_pretrained works in HF implementation.
+ if hasattr(self, "llm") or hasattr(self, "vision_tower") or hasattr(self, "speech_tower") or hasattr(self, "sound_tower") or hasattr(self, "mm_projector") or hasattr(self, "speech_mm_projector") or hasattr(self, "sound_mm_projector"):
+ # already initialized, skipped
+ return
+
+ model_dtype = getattr(config, "model_dtype", "torch.float16")
+ if not hasattr(config, "model_dtype"):
+ warnings.warn("model_dtype not found in config, defaulting to torch.float16.")
+ config.model_dtype = model_dtype
+
+ cfgs = get_model_config(config)
+ print(cfgs)
+        if len(cfgs) == 7:
+            llm_cfg, vision_tower_cfg, speech_tower_cfg, sound_tower_cfg, mm_projector_cfg, speech_mm_projector_cfg, sound_mm_projector_cfg = cfgs
+        else:
+            raise ValueError(
+                "Expected `llm_cfg`, `vision_tower_cfg`, `speech_tower_cfg`, `sound_tower_cfg`, `mm_projector_cfg`, "
+                "`speech_mm_projector_cfg` and `sound_mm_projector_cfg` in the config."
+            )
+
+ self.llm, self.tokenizer = self._init_llm(llm_cfg, config, *args, **kwargs)
+
+ self.sound_tower = build_sound_tower(sound_tower_cfg, config)
+ self.sound_mm_projector = build_sound_mm_projector(sound_mm_projector_cfg, config)
+
+ if isinstance(self.config, dict):
+ self.vocab_size = config.llm_cfg["vocab_size"] + NUM_EXTRA_TOKENS
+ else:
+ self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS
+ logging.info(
+ f"[XGrammar] config is not a dict, loading vocab size from tokenizer {self.tokenizer.vocab_size} + {NUM_EXTRA_TOKENS} => {self.vocab_size}"
+ )
+
+ # XGrammar tokenizer and grammar compiler
+ # lazy init only when specified json output during inference
+ self.grammar_compiler = None
+
+ self.encoders = {}
+ for name in ["sound"]:
+ config = getattr(self.config, f"{name}_encoder")
+ if isinstance(config, str):
+ config = json.loads(config)
+ self.encoders[name] = instantiate(config, parent=self)
+
+ self.post_config()
+ self.is_loaded = True
+
+ assert (
+ self.llm is not None or self.vision_tower is not None or self.speech_tower is not None or self.mm_projector is not None or self.speech_mm_projector is not None
+ ), "At least one of the components must be instantiated."
+
+ @classmethod
+ def load_from_config(cls, model_path_or_config, *args, **kwargs):
+ pass
+
+ ## FIXME we will use this function to load model in the future
+ @classmethod
+ def load_pretrained(cls, model_path_or_config, *args, **kwargs):
+ kwargs.pop("config", None)
+
+ if isinstance(model_path_or_config, str):
+ config = AutoConfig.from_pretrained(model_path_or_config)
+ elif isinstance(model_path_or_config, LlavaConfig):
+ config = model_path_or_config
+ else:
+            raise NotImplementedError(
+                f"Unsupported type for model_path_or_config: {type(model_path_or_config)} "
+                f"(isinstance of LlavaConfig: {isinstance(model_path_or_config, LlavaConfig)})"
+            )
+
+ model_dtype = getattr(config, "model_dtype", "torch.float16")
+ if not hasattr(config, "model_dtype"):
+ warnings.warn("model_dtype not found in config, defaulting to torch.float16.")
+ config.model_dtype = model_dtype
+
+ cfgs = get_model_config(config)
+        if len(cfgs) == 7:
+            llm_cfg, vision_tower_cfg, speech_tower_cfg, sound_tower_cfg, mm_projector_cfg, speech_mm_projector_cfg, sound_mm_projector_cfg = cfgs
+        else:
+            raise ValueError(
+                "Expected `llm_cfg`, `vision_tower_cfg`, `speech_tower_cfg`, `sound_tower_cfg`, `mm_projector_cfg`, "
+                "`speech_mm_projector_cfg` and `sound_mm_projector_cfg` in the config."
+            )
+
+ init_context = [
+ no_init_weights(_enable=True),
+ ]
+
+ with ContextManagers(init_context):
+ vlm = cls(config, *args, **kwargs)
+
+ if hasattr(vlm, "llm") or hasattr(vlm, "vision_tower") or hasattr(vlm, "speech_tower") or hasattr(vlm, "sound_tower") or hasattr(vlm, "mm_projector") or hasattr(vlm, "speech_mm_projector") or hasattr(vlm, "sound_mm_projector"):
+ if vlm.is_loaded:
+ return vlm
+
+ vlm.llm, vlm.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs)
+ vlm.sound_tower = build_sound_tower(sound_tower_cfg, config)
+ vlm.sound_mm_projector = build_sound_mm_projector(sound_mm_projector_cfg, config)
+
+        vlm.post_config()
+        vlm.is_loaded = True
+
+ # FIXME(ligeng, yunhao): llm should never be none here.
+ assert (
+ vlm.llm is not None or vlm.vision_tower is not None or vlm.speech_tower is not None or vlm.mm_projector is not None or vlm.speech_mm_projector is not None
+ ), "At least one of the components must be instantiated."
+ return vlm
+
+ ## FIXME we will use this function to save the model in the future
+ def save_pretrained(self, output_dir, state_dict=None):
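+        # Each component is written to its own sub-directory of `output_dir`:
+        #   output_dir/config.json          (top-level LlavaConfig)
+        #   output_dir/llm/                 (tokenizer + LLM weights)
+        #   output_dir/sound_tower/
+        #   output_dir/sound_mm_projector/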
+ if state_dict is None:
+            # otherwise fetch it from deepspeed
+ # state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
+ state_dict = self.state_dict()
+
+ if getattr(self, "tokenizer", None):
+ self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))
+
+ if self.get_llm():
+ print(f"saving llm to {osp.join(output_dir, 'llm')}")
+ self.llm.config._name_or_path = osp.join(output_dir, "llm")
+ llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k})
+ self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict)
+ self.config.llm_cfg = self.llm.config
+
+
+ if self.get_sound_tower():
+ print(f"saving sound_tower to {osp.join(output_dir, 'sound_tower')}")
+ self.sound_tower.config._name_or_path = osp.join(output_dir, "sound_tower")
+ sound_tower_state_dict = OrderedDict(
+ {k.split("sound_tower.sound_tower.")[-1]: v for k, v in state_dict.items() if "sound_tower" in k}
+ )
+ self.sound_tower.sound_tower.save_pretrained(
+ os.path.join(output_dir, "sound_tower"),
+ state_dict=sound_tower_state_dict,
+ )
+ self.config.sound_tower_cfg = self.sound_tower.config
+
+ if self.get_sound_mm_projector():
+ print(f"saving sound_mm_projector to {osp.join(output_dir, 'sound_mm_projector')}")
+ self.sound_mm_projector.config._name_or_path = osp.join(output_dir, "sound_mm_projector")
+ sound_mm_projector_state_dict = OrderedDict(
+ {k.split("sound_mm_projector.")[-1]: v for k, v in state_dict.items() if "sound_mm_projector" in k}
+ )
+ self.sound_mm_projector.save_pretrained(
+ os.path.join(output_dir, "sound_mm_projector"),
+ state_dict=sound_mm_projector_state_dict,
+ )
+ self.config.sound_mm_projector_cfg = self.sound_mm_projector.config
+
+ ## update and save top-level config
+ self.config._name_or_path = output_dir
+ self.config.architectures = [self.__class__.__name__]
+ self.config.save_pretrained(output_dir)
+
+ def get_llm(self):
+ llm = getattr(self, "llm", None)
+ if type(llm) is list:
+ llm = llm[0]
+ return llm
+
+ def get_lm_head(self):
+ lm_head = getattr(self.get_llm(), "lm_head", None)
+ return lm_head
+
+ def get_sound_tower(self):
+ sound_tower = getattr(self, "sound_tower", None)
+ if type(sound_tower) is list:
+ sound_tower = sound_tower[0]
+ return sound_tower
+
+
+ def get_sound_mm_projector(self):
+ sound_mm_projector = getattr(self, "sound_mm_projector", None)
+ if type(sound_mm_projector) is list:
+ sound_mm_projector = sound_mm_projector[0]
+ return sound_mm_projector
+
+ def post_config(self):
+ self.training = self.get_llm().training
+ ## configuration
+ if getattr(self.config, "llm_cfg", None) is None:
+ self.config.llm_cfg = self.llm.config
+ self.config.speech_tower_cfg = self.speech_tower.config
+ if getattr(self.config, "sound_tower_cfg", None) is None:
+ self.config.sound_tower_cfg = self.sound_tower.config
+ if getattr(self.config, "sound_mm_projector_cfg", None) is None:
+ self.config.sound_mm_projector_cfg = self.sound_mm_projector.config
+
+ def freezed_module_patch(self):
+ """
+        Hugging Face calls model.train() at each training step. To ensure the expected behavior of modules like dropout and batchnorm, we need to call .eval() on the frozen modules.
+ """
+ if self.training:
+ if self.get_llm() and not getattr(self.config, "tune_language_model", False):
+ pass
+
+ if self.get_sound_tower() and not getattr(self.config, "tune_sound_tower", False):
+ self.get_sound_tower().eval()
+ if self.get_sound_mm_projector() and not getattr(self.config, "tune_sound_mm_projector", False):
+ self.get_sound_mm_projector().eval()
+
+
+ def encode_sound(self, sounds, masks=None):
+
+ sound_features = self.get_sound_tower()(sounds, masks)
+ sound_features = self.get_sound_mm_projector()(sound_features)
+ return sound_features
+
+ ## @yunhao: is there a better way to handle function call and attributes for llm?
+ ## support beam search
+ def _temporary_reorder_cache(self, past_key_values, sorted_idx):
+ return self.get_llm()._temporary_reorder_cache(past_key_values, sorted_idx)
+
+ def get_input_embeddings(self):
+ return self.get_llm().get_input_embeddings()
+
+ def get_output_embeddings(self):
+ return self.get_llm().get_output_embeddings()
+
+ def resize_token_embeddings(self, embed_size):
+ self.get_llm().resize_token_embeddings(embed_size)
+
+
+class LlavaMetaForCausalLM(ABC):
+ def _embed(
+ self,
+ input_ids: torch.Tensor,
+ media: Dict[str, List[torch.Tensor]],
+ media_config: Dict[str, Dict[str, Any]],
+ labels: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+        media_meta: Optional[Dict[str, Dict[str, Any]]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ labels = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX)
+ attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool)
+
+ PROCESS_GROUP_MANAGER = get_pg_manager()
+ if PROCESS_GROUP_MANAGER is not None:
+ for name in media:
+ self.encoders[name].end_tokens = None
+
+ # Extract text and media embeddings
+ text_embeds = self.llm.model.embed_tokens(input_ids)
+ media_embeds = self.__embed_media_tokens(media, media_config, media_meta)
+ # This is a workaround to make sure the dummy embeddings are consumed
+ while media_embeds.get("dummy"):
+ dummy_embed = media_embeds["dummy"].popleft()
+ text_embeds += torch.sum(dummy_embed) * 0
+ # Remove padding
+ batch_size = labels.shape[0]
+
+ # Build inverse mapping from token ID to media name
+ media_tokens = {}
+ for name, token_id in self.tokenizer.media_token_ids.items():
+ media_tokens[token_id] = name
+
+ # -------------------------------- #
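+        # Number of valid (unmasked) sound-feature frames per audio clip, rounded to the nearest
+        # multiple of 10; `audio_features_mask` below marks those frames as real and the rest of
+        # each fixed-length sound embedding as padding.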
+ num_audio_tokens = torch.stack(media_meta["sound_embed_masks"], dim=0).sum(-1)
+ num_audio_tokens = torch.tensor([round(int(x) / 10) * 10 for x in num_audio_tokens])
+ num_audios = len(media_embeds['sound']) # length of queue is the number of audios we have in total
+ max_audio_tokens, embed_dim = media_embeds['sound'][0].shape
+
+
+ audio_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to(
+ num_audio_tokens.device
+ ) < num_audio_tokens.unsqueeze(1)
+
+ audio_embeds = []
+ while media_embeds['sound']:
+ audio_embeds.append(media_embeds['sound'].popleft())
+        audio_embeds = torch.stack(audio_embeds, dim=0)
+
+ masked_audio_features = audio_embeds[audio_features_mask].view(-1, embed_dim)
+ batch_size, sequence_length = input_ids.shape
+ _left_padding = torch.any(attention_mask[:, 0] == 0)
+ _right_padding = torch.any(attention_mask[:, -1] == 0)
+
+ left_padding = True
+ if batch_size > 1:
+ if _left_padding and not _right_padding:
+ left_padding = True
+ elif not _left_padding and _right_padding:
+ left_padding = False
+ elif not _left_padding and not _right_padding:
+                # no padding on either side, so we cannot tell from the mask
+ left_padding = self.tokenizer.padding_side == "left"
+ else:
+ # invalid attention_mask
+ raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
+
+ # 1. Create a mask to know where special audio tokens are
+        special_audio_token_mask = input_ids == self.tokenizer.media_token_ids["sound"]  # hard-coded to work only with "sound" media
+ num_special_audio_tokens = torch.sum(special_audio_token_mask, dim=-1)
+
+ # In case the Audio model or the Language model has been offloaded to CPU, we need to manually
+ # set the corresponding tensors into their correct target device.
+ target_device = text_embeds.device
+ attention_mask = attention_mask.to(target_device)
+ input_ids = input_ids.to(target_device)
+ num_audio_tokens = num_audio_tokens.to(target_device)
+ batch_indices, non_audio_indices = torch.where(
+ (input_ids != self.tokenizer.media_token_ids['sound']) & (attention_mask == 1)
+ )
+
+ # 2. Compute the positions where text should be written
+ # Calculate new positions for text tokens in merged audio-text sequence.
+ # `special_audio_token_mask` identifies audio tokens. Each audio token will be replaced by `audio_feat_lengths - 1` text tokens.
+ # `torch.cumsum` computes how each audio token shifts subsequent text token positions.
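+        # Toy illustration (assuming a single clip that expands to 4 feature frames; real counts
+        # are multiples of 10): for input_ids = [txt, txt, <sound>, txt],
+        #   token_placeholder_num = [1, 1, 4, 1]
+        #   new_token_positions   = cumsum - 1 = [0, 1, 5, 6]
+        # so the text tokens land at positions 0, 1 and 6 and the audio features fill slots 2-5.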
+ token_placeholder_num = torch.zeros_like(input_ids)
+ token_placeholder_num[special_audio_token_mask] = num_audio_tokens.long() - 1
+ token_placeholder_num = token_placeholder_num + 1
+ new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
+ max_token_num = token_placeholder_num.sum(-1).max()
+ nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1]
+ if left_padding:
+ new_token_positions += nb_audio_pad[:, None] # offset for left padding
+ text_to_overwrite = new_token_positions[batch_indices, non_audio_indices]
+ batch_indices, non_audio_indices, text_to_overwrite = (
+ batch_indices.to(target_device),
+ non_audio_indices.to(target_device),
+ text_to_overwrite.to(target_device),
+ )
+
+ # 3. Create the full embedding, already padded to the maximum position
+ final_embedding = torch.zeros(
+ batch_size, max_token_num, embed_dim, dtype=text_embeds.dtype, device=text_embeds.device
+ )
+ final_attention_mask = torch.zeros(
+ batch_size, max_token_num, dtype=attention_mask.dtype, device=text_embeds.device
+ )
+ final_input_ids = torch.full(
+ (batch_size, max_token_num), self.tokenizer.pad_token_id, dtype=input_ids.dtype, device=text_embeds.device
+ )
+
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "